From 1cc705c5c75a62576d5f01f99e32f1bf6174aeb7 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 22 Sep 2025 17:31:13 -0400 Subject: [PATCH] Enhance crawler with safety features and site-specific report directories --- internal/crawler/crawler.go | 192 ++++++++++++++++++++++++++++++++++++ main.go | 41 ++++++-- 2 files changed, 226 insertions(+), 7 deletions(-) diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go index a36eb11..6d30c30 100644 --- a/internal/crawler/crawler.go +++ b/internal/crawler/crawler.go @@ -148,6 +148,198 @@ func Crawl(ctx context.Context, startURL string, maxDepth int, concurrency int, return visited, errs, outlinks, pageInfos } +// CrawlWithSafety is a safer version of the crawler that adds a global timeout +// and limits the maximum number of URLs to process to prevent infinite loops. +func CrawlWithSafety(ctx context.Context, startURL string, maxDepth int, concurrency int, + sameHostOnly bool, client *http.Client, userAgent string, + visitedCallback func(string, int, int), errorCallback func(string, error, int), + maxURLs int) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) { + + visited := make(map[string]struct{}) + errs := make(map[string]error) + outlinks := make(map[string]map[string]struct{}) + pageInfos := make(map[string]PageInfo) + var mu sync.Mutex + var urlCounter int + + origin := urlutil.Origin(startURL) + + tasks := make(chan task, concurrency*2) + wgWorkers := sync.WaitGroup{} + wgTasks := sync.WaitGroup{} + + enqueue := func(t task) { + mu.Lock() + defer mu.Unlock() + + // Check if we've reached the URL limit + if urlCounter >= maxURLs { + return + } + + urlCounter++ + wgTasks.Add(1) + select { + case tasks <- t: + // Successfully enqueued + case <-ctx.Done(): + // Context canceled or timed out + wgTasks.Done() + } + } + + worker := func() { + defer wgWorkers.Done() + for { + select { + case <-ctx.Done(): + return + case tk, ok := <-tasks: + if !ok { + return + } + + if ctx.Err() != nil { + wgTasks.Done() + return + } + + mu.Lock() + if _, seen := visited[tk.url]; seen { + mu.Unlock() + wgTasks.Done() + continue + } + visited[tk.url] = struct{}{} + mu.Unlock() + + if visitedCallback != nil { + visitedCallback(tk.url, tk.depth, len(tasks)) + } + + // Create a context with timeout for this specific request + reqCtx, reqCancel := context.WithTimeout(ctx, 10*time.Second) + + start := time.Now() + req, _ := http.NewRequestWithContext(reqCtx, http.MethodGet, tk.url, nil) + req.Header.Set("User-Agent", userAgent) + resp, err := client.Do(req) + reqCancel() // Cancel the request context + + if err != nil { + mu.Lock() + errs[tk.url] = err + pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth} + mu.Unlock() + + if errorCallback != nil { + errorCallback(tk.url, err, len(tasks)) + } + wgTasks.Done() + continue + } + + func() { + defer resp.Body.Close() + ct := resp.Header.Get("Content-Type") + // Default meta values + meta := PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth} + if resp.ContentLength > 0 { + meta.ContentLength = int(resp.ContentLength) + } + if resp.StatusCode != http.StatusOK || ct == "" || (ct != "text/html" && !hasPrefix(ct, "text/html")) { + mu.Lock() + pageInfos[tk.url] = meta + mu.Unlock() + return + } + + // Limit body read to prevent memory issues + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) // 1MB limit + meta.ContentLength = len(body) + meta.Title = htmlx.ExtractTitle(stringsReader(string(body))) + hrefs := htmlx.ExtractAnchors(stringsReader(string(body))) + + var toEnqueue []string + for _, href := range hrefs { + abs, ok := urlutil.Normalize(tk.url, href) + if !ok { + continue + } + mu.Lock() + m, ok2 := outlinks[tk.url] + if !ok2 { + m = make(map[string]struct{}) + outlinks[tk.url] = m + } + m[abs] = struct{}{} + mu.Unlock() + + if tk.depth < maxDepth { + if !sameHostOnly || urlutil.SameHost(origin, abs) { + toEnqueue = append(toEnqueue, abs) + } + } + } + + // Limit the number of links to follow from a single page + if len(toEnqueue) > 50 { + toEnqueue = toEnqueue[:50] + } + + for _, u := range toEnqueue { + enqueue(task{url: u, depth: tk.depth + 1}) + } + mu.Lock() + pageInfos[tk.url] = meta + mu.Unlock() + }() + wgTasks.Done() + } + } + } + + for i := 0; i < concurrency; i++ { + wgWorkers.Add(1) + go worker() + } + + // Close the tasks channel when all enqueued tasks are processed or timeout occurs + go func() { + done := make(chan struct{}) + go func() { + wgTasks.Wait() + close(done) + }() + + select { + case <-done: + // All tasks completed + case <-ctx.Done(): + // Timeout occurred + } + close(tasks) + }() + + enqueue(task{url: startURL, depth: 0}) + + // Wait for workers to finish or timeout + workersDone := make(chan struct{}) + go func() { + wgWorkers.Wait() + close(workersDone) + }() + + select { + case <-workersDone: + // Workers finished normally + case <-ctx.Done(): + // Timeout occurred + } + + return visited, errs, outlinks, pageInfos +} + func hasPrefix(s string, prefix string) bool { return len(s) >= len(prefix) && s[:len(prefix)] == prefix } diff --git a/main.go b/main.go index ccf6193..4b0ed9d 100644 --- a/main.go +++ b/main.go @@ -32,6 +32,8 @@ func main() { var output string var quiet bool var exportDir string + var maxURLs int + var globalTimeout int flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)") flag.IntVar(&concurrency, "concurrency", 10, "Number of concurrent workers") @@ -42,6 +44,8 @@ func main() { flag.StringVar(&output, "output", "text", "Output format: text|json") flag.BoolVar(&quiet, "quiet", false, "Suppress progress output") flag.StringVar(&exportDir, "export-dir", "exports", "Directory to write CSV/NDJSON exports into (set empty to disable)") + flag.IntVar(&maxURLs, "max-urls", 500, "Maximum number of URLs to crawl") + flag.IntVar(&globalTimeout, "global-timeout", 120, "Global timeout in seconds for the entire crawl") flag.Parse() if strings.TrimSpace(target) == "" { @@ -105,7 +109,11 @@ func main() { currentURL.Store(u) } - visited, crawlErrs, outlinks, pageInfo := crawler.Crawl(ctx, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback) + // Create a context with timeout for the entire crawl + ctxWithGlobalTimeout, cancelGlobal := context.WithTimeout(ctx, time.Duration(globalTimeout)*time.Second) + defer cancelGlobal() + + visited, crawlErrs, outlinks, pageInfo := crawler.CrawlWithSafety(ctxWithGlobalTimeout, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback, maxURLs) // Clear progress line before moving to next phase if !quiet { @@ -195,8 +203,8 @@ func main() { } } - // Save JSON report to ./reports/.json by default (ignored by git) - if err := saveReportJSON("reports", reports); err != nil { + // Save JSON report to ./reports//report.json by default (ignored by git) + if err := saveReportToSiteDir(reports); err != nil { fmt.Fprintf(os.Stderr, "save report error: %v\n", err) } @@ -341,21 +349,40 @@ func linkStatusesToNDJSON(r report.Report) []ndjsonItem { return res } -func saveReportJSON(baseDir string, r report.Report) error { +// saveReportToSiteDir saves the report to a subdirectory named after the site's hostname +// under the "reports" directory. +func saveReportToSiteDir(r report.Report) error { u, err := url.Parse(r.Target) if err != nil || u.Host == "" { return fmt.Errorf("invalid target for save: %s", r.Target) } - if err := os.MkdirAll(baseDir, 0o755); err != nil { + + // Create base reports directory + reportsDir := "reports" + if err := os.MkdirAll(reportsDir, 0o755); err != nil { return err } - path := filepath.Join(baseDir, u.Host+".json") + + // Create subdirectory for this site + siteDir := filepath.Join(reportsDir, u.Host) + if err := os.MkdirAll(siteDir, 0o755); err != nil { + return err + } + + // Save report to site subdirectory + path := filepath.Join(siteDir, "report.json") f, err := os.Create(path) if err != nil { return err } defer f.Close() + enc := json.NewEncoder(f) enc.SetIndent("", " ") - return enc.Encode(r) + if err := enc.Encode(r); err != nil { + return err + } + + fmt.Fprintf(os.Stderr, "Report saved to %s\n", path) + return nil }