Enhance crawler with safety features and site-specific report directories

Colin 2025-09-22 17:31:13 -04:00
parent f2ae09dd78
commit 1cc705c5c7
Signed by: colin
SSH Key Fingerprint: SHA256:nRPCQTeMFLdGytxRQmPVK9VXY3/ePKQ5lGRyJhT5DY8
2 changed files with 226 additions and 7 deletions


@@ -148,6 +148,198 @@ func Crawl(ctx context.Context, startURL string, maxDepth int, concurrency int,
    return visited, errs, outlinks, pageInfos
}
// CrawlWithSafety is a safer version of the crawler that adds a global timeout
// and limits the maximum number of URLs to process to prevent infinite loops.
func CrawlWithSafety(ctx context.Context, startURL string, maxDepth int, concurrency int,
    sameHostOnly bool, client *http.Client, userAgent string,
    visitedCallback func(string, int, int), errorCallback func(string, error, int),
    maxURLs int) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) {
    visited := make(map[string]struct{})
    errs := make(map[string]error)
    outlinks := make(map[string]map[string]struct{})
    pageInfos := make(map[string]PageInfo)
    var mu sync.Mutex
    var urlCounter int
    origin := urlutil.Origin(startURL)
    tasks := make(chan task, concurrency*2)
    wgWorkers := sync.WaitGroup{}
    wgTasks := sync.WaitGroup{}
    enqueue := func(t task) {
        mu.Lock()
        defer mu.Unlock()
        // Check if we've reached the URL limit
        if urlCounter >= maxURLs {
            return
        }
        urlCounter++
        wgTasks.Add(1)
        select {
        case tasks <- t:
            // Successfully enqueued
        case <-ctx.Done():
            // Context canceled or timed out
            wgTasks.Done()
        }
    }
    worker := func() {
        defer wgWorkers.Done()
        for {
            select {
            case <-ctx.Done():
                return
            case tk, ok := <-tasks:
                if !ok {
                    return
                }
                if ctx.Err() != nil {
                    wgTasks.Done()
                    return
                }
                mu.Lock()
                if _, seen := visited[tk.url]; seen {
                    mu.Unlock()
                    wgTasks.Done()
                    continue
                }
                visited[tk.url] = struct{}{}
                mu.Unlock()
                if visitedCallback != nil {
                    visitedCallback(tk.url, tk.depth, len(tasks))
                }
                // Create a context with timeout for this specific request
                reqCtx, reqCancel := context.WithTimeout(ctx, 10*time.Second)
                start := time.Now()
                req, _ := http.NewRequestWithContext(reqCtx, http.MethodGet, tk.url, nil)
                req.Header.Set("User-Agent", userAgent)
                resp, err := client.Do(req)
                reqCancel() // Cancel the request context
                if err != nil {
                    mu.Lock()
                    errs[tk.url] = err
                    pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
                    mu.Unlock()
                    if errorCallback != nil {
                        errorCallback(tk.url, err, len(tasks))
                    }
                    wgTasks.Done()
                    continue
                }
                func() {
                    defer resp.Body.Close()
                    ct := resp.Header.Get("Content-Type")
                    // Default meta values
                    meta := PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
                    if resp.ContentLength > 0 {
                        meta.ContentLength = int(resp.ContentLength)
                    }
                    if resp.StatusCode != http.StatusOK || ct == "" || (ct != "text/html" && !hasPrefix(ct, "text/html")) {
                        mu.Lock()
                        pageInfos[tk.url] = meta
                        mu.Unlock()
                        return
                    }
                    // Limit body read to prevent memory issues
                    body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) // 1MB limit
                    meta.ContentLength = len(body)
                    meta.Title = htmlx.ExtractTitle(stringsReader(string(body)))
                    hrefs := htmlx.ExtractAnchors(stringsReader(string(body)))
                    var toEnqueue []string
                    for _, href := range hrefs {
                        abs, ok := urlutil.Normalize(tk.url, href)
                        if !ok {
                            continue
                        }
                        mu.Lock()
                        m, ok2 := outlinks[tk.url]
                        if !ok2 {
                            m = make(map[string]struct{})
                            outlinks[tk.url] = m
                        }
                        m[abs] = struct{}{}
                        mu.Unlock()
                        if tk.depth < maxDepth {
                            if !sameHostOnly || urlutil.SameHost(origin, abs) {
                                toEnqueue = append(toEnqueue, abs)
                            }
                        }
                    }
                    // Limit the number of links to follow from a single page
                    if len(toEnqueue) > 50 {
                        toEnqueue = toEnqueue[:50]
                    }
                    for _, u := range toEnqueue {
                        enqueue(task{url: u, depth: tk.depth + 1})
                    }
                    mu.Lock()
                    pageInfos[tk.url] = meta
                    mu.Unlock()
                }()
                wgTasks.Done()
            }
        }
    }
    for i := 0; i < concurrency; i++ {
        wgWorkers.Add(1)
        go worker()
    }
    // Close the tasks channel when all enqueued tasks are processed or timeout occurs
    go func() {
        done := make(chan struct{})
        go func() {
            wgTasks.Wait()
            close(done)
        }()
        select {
        case <-done:
            // All tasks completed
        case <-ctx.Done():
            // Timeout occurred
        }
        close(tasks)
    }()
    enqueue(task{url: startURL, depth: 0})
    // Wait for workers to finish or timeout
    workersDone := make(chan struct{})
    go func() {
        wgWorkers.Wait()
        close(workersDone)
    }()
    select {
    case <-workersDone:
        // Workers finished normally
    case <-ctx.Done():
        // Timeout occurred
    }
    return visited, errs, outlinks, pageInfos
}
func hasPrefix(s string, prefix string) bool {
    return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}
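
For reference, a minimal caller-side sketch of how CrawlWithSafety is meant to be driven with a bounded context. The import path, timeout values, and user agent string below are placeholders, not part of this commit; both callbacks are passed as nil, which the function explicitly allows via its nil checks.

package main

import (
    "context"
    "net/http"
    "time"

    "example.com/sitecrawler/crawler" // hypothetical import path for the crawler package
)

func main() {
    // Bound the whole crawl: CrawlWithSafety stops enqueueing once the context
    // is done and once the maxURLs budget (here 500) has been spent.
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
    defer cancel()

    client := &http.Client{Timeout: 15 * time.Second}
    visited, errs, outlinks, pages := crawler.CrawlWithSafety(
        ctx, "https://example.com", 2, 10, true, client, "safety-crawler/0.1",
        nil, nil, // visited/error callbacks are optional (nil-checked by the crawler)
        500, // maxURLs
    )
    _ = visited
    _ = errs
    _ = outlinks
    _ = pages
}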

main.go

@@ -32,6 +32,8 @@ func main() {
    var output string
    var quiet bool
    var exportDir string
    var maxURLs int
    var globalTimeout int
    flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)")
    flag.IntVar(&concurrency, "concurrency", 10, "Number of concurrent workers")
@@ -42,6 +44,8 @@ func main() {
    flag.StringVar(&output, "output", "text", "Output format: text|json")
    flag.BoolVar(&quiet, "quiet", false, "Suppress progress output")
    flag.StringVar(&exportDir, "export-dir", "exports", "Directory to write CSV/NDJSON exports into (set empty to disable)")
    flag.IntVar(&maxURLs, "max-urls", 500, "Maximum number of URLs to crawl")
    flag.IntVar(&globalTimeout, "global-timeout", 120, "Global timeout in seconds for the entire crawl")
    flag.Parse()
    if strings.TrimSpace(target) == "" {
@@ -105,7 +109,11 @@ func main() {
        currentURL.Store(u)
    }
-   visited, crawlErrs, outlinks, pageInfo := crawler.Crawl(ctx, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback)
    // Create a context with timeout for the entire crawl
    ctxWithGlobalTimeout, cancelGlobal := context.WithTimeout(ctx, time.Duration(globalTimeout)*time.Second)
    defer cancelGlobal()
    visited, crawlErrs, outlinks, pageInfo := crawler.CrawlWithSafety(ctxWithGlobalTimeout, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback, maxURLs)
    // Clear progress line before moving to next phase
    if !quiet {
@@ -195,8 +203,8 @@ func main() {
        }
    }
-   // Save JSON report to ./reports/<host>.json by default (ignored by git)
-   if err := saveReportJSON("reports", reports); err != nil {
    // Save JSON report to ./reports/<host>/report.json by default (ignored by git)
    if err := saveReportToSiteDir(reports); err != nil {
        fmt.Fprintf(os.Stderr, "save report error: %v\n", err)
    }
@@ -341,21 +349,40 @@ func linkStatusesToNDJSON(r report.Report) []ndjsonItem {
    return res
}
-func saveReportJSON(baseDir string, r report.Report) error {
// saveReportToSiteDir saves the report to a subdirectory named after the site's hostname
// under the "reports" directory.
func saveReportToSiteDir(r report.Report) error {
    u, err := url.Parse(r.Target)
    if err != nil || u.Host == "" {
        return fmt.Errorf("invalid target for save: %s", r.Target)
    }
-   if err := os.MkdirAll(baseDir, 0o755); err != nil {
    // Create base reports directory
    reportsDir := "reports"
    if err := os.MkdirAll(reportsDir, 0o755); err != nil {
        return err
    }
-   path := filepath.Join(baseDir, u.Host+".json")
    // Create subdirectory for this site
    siteDir := filepath.Join(reportsDir, u.Host)
    if err := os.MkdirAll(siteDir, 0o755); err != nil {
        return err
    }
    // Save report to site subdirectory
    path := filepath.Join(siteDir, "report.json")
    f, err := os.Create(path)
    if err != nil {
        return err
    }
    defer f.Close()
    enc := json.NewEncoder(f)
    enc.SetIndent("", " ")
-   return enc.Encode(r)
    if err := enc.Encode(r); err != nil {
        return err
    }
    fmt.Fprintf(os.Stderr, "Report saved to %s\n", path)
    return nil
}
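
To make the new layout concrete, a tiny standalone sketch (the target URL here is only an example) that mirrors the url.Parse and filepath.Join calls above and prints where the report now lands:

package main

import (
    "fmt"
    "net/url"
    "path/filepath"
)

func main() {
    u, _ := url.Parse("https://example.com/docs/start")
    // Old layout: reports/example.com.json
    // New layout: reports/<host>/report.json
    fmt.Println(filepath.Join("reports", u.Host, "report.json")) // reports/example.com/report.json on Unix
}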