Enhance crawler with safety features and site-specific report directories
This commit is contained in:
		
							parent
							
								
									f2ae09dd78
								
							
						
					
					
						commit
						1cc705c5c7
					
				|  | @ -148,6 +148,198 @@ func Crawl(ctx context.Context, startURL string, maxDepth int, concurrency int, | ||||||
| 	return visited, errs, outlinks, pageInfos | 	return visited, errs, outlinks, pageInfos | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CrawlWithSafety is a safer version of the crawler that adds a global timeout
 | ||||||
|  | // and limits the maximum number of URLs to process to prevent infinite loops.
 | ||||||
|  | func CrawlWithSafety(ctx context.Context, startURL string, maxDepth int, concurrency int,  | ||||||
|  | 	sameHostOnly bool, client *http.Client, userAgent string,  | ||||||
|  | 	visitedCallback func(string, int, int), errorCallback func(string, error, int), | ||||||
|  | 	maxURLs int) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) { | ||||||
|  | 	 | ||||||
|  | 	visited := make(map[string]struct{}) | ||||||
|  | 	errs := make(map[string]error) | ||||||
|  | 	outlinks := make(map[string]map[string]struct{}) | ||||||
|  | 	pageInfos := make(map[string]PageInfo) | ||||||
|  | 	var mu sync.Mutex | ||||||
|  | 	var urlCounter int | ||||||
|  | 
 | ||||||
|  | 	origin := urlutil.Origin(startURL) | ||||||
|  | 
 | ||||||
|  | 	tasks := make(chan task, concurrency*2) | ||||||
|  | 	wgWorkers := sync.WaitGroup{} | ||||||
|  | 	wgTasks := sync.WaitGroup{} | ||||||
|  | 
 | ||||||
|  | 	enqueue := func(t task) { | ||||||
|  | 		mu.Lock() | ||||||
|  | 		defer mu.Unlock() | ||||||
|  | 		 | ||||||
|  | 		// Check if we've reached the URL limit
 | ||||||
|  | 		if urlCounter >= maxURLs { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		 | ||||||
|  | 		urlCounter++ | ||||||
|  | 		wgTasks.Add(1) | ||||||
|  | 		select { | ||||||
|  | 		case tasks <- t: | ||||||
|  | 			// Successfully enqueued
 | ||||||
|  | 		case <-ctx.Done(): | ||||||
|  | 			// Context canceled or timed out
 | ||||||
|  | 			wgTasks.Done() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	worker := func() { | ||||||
|  | 		defer wgWorkers.Done() | ||||||
|  | 		for { | ||||||
|  | 			select { | ||||||
|  | 			case <-ctx.Done(): | ||||||
|  | 				return | ||||||
|  | 			case tk, ok := <-tasks: | ||||||
|  | 				if !ok { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				 | ||||||
|  | 				if ctx.Err() != nil { | ||||||
|  | 					wgTasks.Done() | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				 | ||||||
|  | 				mu.Lock() | ||||||
|  | 				if _, seen := visited[tk.url]; seen { | ||||||
|  | 					mu.Unlock() | ||||||
|  | 					wgTasks.Done() | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 				visited[tk.url] = struct{}{} | ||||||
|  | 				mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 				if visitedCallback != nil { | ||||||
|  | 					visitedCallback(tk.url, tk.depth, len(tasks)) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// Create a context with timeout for this specific request
 | ||||||
|  | 				reqCtx, reqCancel := context.WithTimeout(ctx, 10*time.Second) | ||||||
|  | 				 | ||||||
|  | 				start := time.Now() | ||||||
|  | 				req, _ := http.NewRequestWithContext(reqCtx, http.MethodGet, tk.url, nil) | ||||||
|  | 				req.Header.Set("User-Agent", userAgent) | ||||||
|  | 				resp, err := client.Do(req) | ||||||
|  | 				reqCancel() // Cancel the request context
 | ||||||
|  | 				 | ||||||
|  | 				if err != nil { | ||||||
|  | 					mu.Lock() | ||||||
|  | 					errs[tk.url] = err | ||||||
|  | 					pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth} | ||||||
|  | 					mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 					if errorCallback != nil { | ||||||
|  | 						errorCallback(tk.url, err, len(tasks)) | ||||||
|  | 					} | ||||||
|  | 					wgTasks.Done() | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 				 | ||||||
|  | 				func() { | ||||||
|  | 					defer resp.Body.Close() | ||||||
|  | 					ct := resp.Header.Get("Content-Type") | ||||||
|  | 					// Default meta values
 | ||||||
|  | 					meta := PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth} | ||||||
|  | 					if resp.ContentLength > 0 { | ||||||
|  | 						meta.ContentLength = int(resp.ContentLength) | ||||||
|  | 					} | ||||||
|  | 					if resp.StatusCode != http.StatusOK || ct == "" || (ct != "text/html" && !hasPrefix(ct, "text/html")) { | ||||||
|  | 						mu.Lock() | ||||||
|  | 						pageInfos[tk.url] = meta | ||||||
|  | 						mu.Unlock() | ||||||
|  | 						return | ||||||
|  | 					} | ||||||
|  | 					 | ||||||
|  | 					// Limit body read to prevent memory issues
 | ||||||
|  | 					body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) // 1MB limit
 | ||||||
|  | 					meta.ContentLength = len(body) | ||||||
|  | 					meta.Title = htmlx.ExtractTitle(stringsReader(string(body))) | ||||||
|  | 					hrefs := htmlx.ExtractAnchors(stringsReader(string(body))) | ||||||
|  | 					 | ||||||
|  | 					var toEnqueue []string | ||||||
|  | 					for _, href := range hrefs { | ||||||
|  | 						abs, ok := urlutil.Normalize(tk.url, href) | ||||||
|  | 						if !ok { | ||||||
|  | 							continue | ||||||
|  | 						} | ||||||
|  | 						mu.Lock() | ||||||
|  | 						m, ok2 := outlinks[tk.url] | ||||||
|  | 						if !ok2 { | ||||||
|  | 							m = make(map[string]struct{}) | ||||||
|  | 							outlinks[tk.url] = m | ||||||
|  | 						} | ||||||
|  | 						m[abs] = struct{}{} | ||||||
|  | 						mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 						if tk.depth < maxDepth { | ||||||
|  | 							if !sameHostOnly || urlutil.SameHost(origin, abs) { | ||||||
|  | 								toEnqueue = append(toEnqueue, abs) | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 					 | ||||||
|  | 					// Limit the number of links to follow from a single page
 | ||||||
|  | 					if len(toEnqueue) > 50 { | ||||||
|  | 						toEnqueue = toEnqueue[:50] | ||||||
|  | 					} | ||||||
|  | 					 | ||||||
|  | 					for _, u := range toEnqueue { | ||||||
|  | 						enqueue(task{url: u, depth: tk.depth + 1}) | ||||||
|  | 					} | ||||||
|  | 					mu.Lock() | ||||||
|  | 					pageInfos[tk.url] = meta | ||||||
|  | 					mu.Unlock() | ||||||
|  | 				}() | ||||||
|  | 				wgTasks.Done() | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for i := 0; i < concurrency; i++ { | ||||||
|  | 		wgWorkers.Add(1) | ||||||
|  | 		go worker() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Close the tasks channel when all enqueued tasks are processed or timeout occurs
 | ||||||
|  | 	go func() { | ||||||
|  | 		done := make(chan struct{}) | ||||||
|  | 		go func() { | ||||||
|  | 			wgTasks.Wait() | ||||||
|  | 			close(done) | ||||||
|  | 		}() | ||||||
|  | 		 | ||||||
|  | 		select { | ||||||
|  | 		case <-done: | ||||||
|  | 			// All tasks completed
 | ||||||
|  | 		case <-ctx.Done(): | ||||||
|  | 			// Timeout occurred
 | ||||||
|  | 		} | ||||||
|  | 		close(tasks) | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	enqueue(task{url: startURL, depth: 0}) | ||||||
|  | 	 | ||||||
|  | 	// Wait for workers to finish or timeout
 | ||||||
|  | 	workersDone := make(chan struct{}) | ||||||
|  | 	go func() { | ||||||
|  | 		wgWorkers.Wait() | ||||||
|  | 		close(workersDone) | ||||||
|  | 	}() | ||||||
|  | 	 | ||||||
|  | 	select { | ||||||
|  | 	case <-workersDone: | ||||||
|  | 		// Workers finished normally
 | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		// Timeout occurred
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return visited, errs, outlinks, pageInfos | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func hasPrefix(s string, prefix string) bool { | func hasPrefix(s string, prefix string) bool { | ||||||
| 	return len(s) >= len(prefix) && s[:len(prefix)] == prefix | 	return len(s) >= len(prefix) && s[:len(prefix)] == prefix | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										41
									
								
								main.go
								
								
								
								
							
							
						
						
									
										41
									
								
								main.go
								
								
								
								
							|  | @ -32,6 +32,8 @@ func main() { | ||||||
| 	var output string | 	var output string | ||||||
| 	var quiet bool | 	var quiet bool | ||||||
| 	var exportDir string | 	var exportDir string | ||||||
|  | 	var maxURLs int | ||||||
|  | 	var globalTimeout int | ||||||
| 
 | 
 | ||||||
| 	flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)") | 	flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)") | ||||||
| 	flag.IntVar(&concurrency, "concurrency", 10, "Number of concurrent workers") | 	flag.IntVar(&concurrency, "concurrency", 10, "Number of concurrent workers") | ||||||
|  | @ -42,6 +44,8 @@ func main() { | ||||||
| 	flag.StringVar(&output, "output", "text", "Output format: text|json") | 	flag.StringVar(&output, "output", "text", "Output format: text|json") | ||||||
| 	flag.BoolVar(&quiet, "quiet", false, "Suppress progress output") | 	flag.BoolVar(&quiet, "quiet", false, "Suppress progress output") | ||||||
| 	flag.StringVar(&exportDir, "export-dir", "exports", "Directory to write CSV/NDJSON exports into (set empty to disable)") | 	flag.StringVar(&exportDir, "export-dir", "exports", "Directory to write CSV/NDJSON exports into (set empty to disable)") | ||||||
|  | 	flag.IntVar(&maxURLs, "max-urls", 500, "Maximum number of URLs to crawl") | ||||||
|  | 	flag.IntVar(&globalTimeout, "global-timeout", 120, "Global timeout in seconds for the entire crawl") | ||||||
| 	flag.Parse() | 	flag.Parse() | ||||||
| 
 | 
 | ||||||
| 	if strings.TrimSpace(target) == "" { | 	if strings.TrimSpace(target) == "" { | ||||||
|  | @ -105,7 +109,11 @@ func main() { | ||||||
| 		currentURL.Store(u) | 		currentURL.Store(u) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	visited, crawlErrs, outlinks, pageInfo := crawler.Crawl(ctx, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback) | 	// Create a context with timeout for the entire crawl
 | ||||||
|  | 	ctxWithGlobalTimeout, cancelGlobal := context.WithTimeout(ctx, time.Duration(globalTimeout)*time.Second) | ||||||
|  | 	defer cancelGlobal() | ||||||
|  | 	 | ||||||
|  | 	visited, crawlErrs, outlinks, pageInfo := crawler.CrawlWithSafety(ctxWithGlobalTimeout, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback, maxURLs) | ||||||
| 
 | 
 | ||||||
| 	// Clear progress line before moving to next phase
 | 	// Clear progress line before moving to next phase
 | ||||||
| 	if !quiet { | 	if !quiet { | ||||||
|  | @ -195,8 +203,8 @@ func main() { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Save JSON report to ./reports/<host>.json by default (ignored by git)
 | 	// Save JSON report to ./reports/<host>/report.json by default (ignored by git)
 | ||||||
| 	if err := saveReportJSON("reports", reports); err != nil { | 	if err := saveReportToSiteDir(reports); err != nil { | ||||||
| 		fmt.Fprintf(os.Stderr, "save report error: %v\n", err) | 		fmt.Fprintf(os.Stderr, "save report error: %v\n", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -341,21 +349,40 @@ func linkStatusesToNDJSON(r report.Report) []ndjsonItem { | ||||||
| 	return res | 	return res | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func saveReportJSON(baseDir string, r report.Report) error { | // saveReportToSiteDir saves the report to a subdirectory named after the site's hostname
 | ||||||
|  | // under the "reports" directory.
 | ||||||
|  | func saveReportToSiteDir(r report.Report) error { | ||||||
| 	u, err := url.Parse(r.Target) | 	u, err := url.Parse(r.Target) | ||||||
| 	if err != nil || u.Host == "" { | 	if err != nil || u.Host == "" { | ||||||
| 		return fmt.Errorf("invalid target for save: %s", r.Target) | 		return fmt.Errorf("invalid target for save: %s", r.Target) | ||||||
| 	} | 	} | ||||||
| 	if err := os.MkdirAll(baseDir, 0o755); err != nil { | 	 | ||||||
|  | 	// Create base reports directory
 | ||||||
|  | 	reportsDir := "reports" | ||||||
|  | 	if err := os.MkdirAll(reportsDir, 0o755); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	path := filepath.Join(baseDir, u.Host+".json") | 	 | ||||||
|  | 	// Create subdirectory for this site
 | ||||||
|  | 	siteDir := filepath.Join(reportsDir, u.Host) | ||||||
|  | 	if err := os.MkdirAll(siteDir, 0o755); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	 | ||||||
|  | 	// Save report to site subdirectory
 | ||||||
|  | 	path := filepath.Join(siteDir, "report.json") | ||||||
| 	f, err := os.Create(path) | 	f, err := os.Create(path) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
|  | 	 | ||||||
| 	enc := json.NewEncoder(f) | 	enc := json.NewEncoder(f) | ||||||
| 	enc.SetIndent("", "  ") | 	enc.SetIndent("", "  ") | ||||||
| 	return enc.Encode(r) | 	if err := enc.Encode(r); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	 | ||||||
|  | 	fmt.Fprintf(os.Stderr, "Report saved to %s\n", path) | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue