Enhance crawler with safety features and site-specific report directories
parent f2ae09dd78
commit 1cc705c5c7
@@ -148,6 +148,198 @@ func Crawl(ctx context.Context, startURL string, maxDepth int, concurrency int,
 	return visited, errs, outlinks, pageInfos
 }
+
+// CrawlWithSafety is a safer version of the crawler that adds a global timeout
+// and limits the maximum number of URLs to process to prevent infinite loops.
+func CrawlWithSafety(ctx context.Context, startURL string, maxDepth int, concurrency int,
+	sameHostOnly bool, client *http.Client, userAgent string,
+	visitedCallback func(string, int, int), errorCallback func(string, error, int),
+	maxURLs int) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) {
+
+	visited := make(map[string]struct{})
+	errs := make(map[string]error)
+	outlinks := make(map[string]map[string]struct{})
+	pageInfos := make(map[string]PageInfo)
+	var mu sync.Mutex
+	var urlCounter int
+
+	origin := urlutil.Origin(startURL)
+
+	tasks := make(chan task, concurrency*2)
+	wgWorkers := sync.WaitGroup{}
+	wgTasks := sync.WaitGroup{}
+
+	enqueue := func(t task) {
+		mu.Lock()
+		defer mu.Unlock()
+
+		// Check if we've reached the URL limit
+		if urlCounter >= maxURLs {
+			return
+		}
+
+		urlCounter++
+		wgTasks.Add(1)
+		select {
+		case tasks <- t:
+			// Successfully enqueued
+		case <-ctx.Done():
+			// Context canceled or timed out
+			wgTasks.Done()
+		}
+	}
+
+	worker := func() {
+		defer wgWorkers.Done()
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case tk, ok := <-tasks:
+				if !ok {
+					return
+				}
+
+				if ctx.Err() != nil {
+					wgTasks.Done()
+					return
+				}
+
+				mu.Lock()
+				if _, seen := visited[tk.url]; seen {
+					mu.Unlock()
+					wgTasks.Done()
+					continue
+				}
+				visited[tk.url] = struct{}{}
+				mu.Unlock()
+
+				if visitedCallback != nil {
+					visitedCallback(tk.url, tk.depth, len(tasks))
+				}
+
+				// Create a context with timeout for this specific request
+				reqCtx, reqCancel := context.WithTimeout(ctx, 10*time.Second)
+
+				start := time.Now()
+				req, _ := http.NewRequestWithContext(reqCtx, http.MethodGet, tk.url, nil)
+				req.Header.Set("User-Agent", userAgent)
+				resp, err := client.Do(req)
+				reqCancel() // Cancel the request context
+
+				if err != nil {
+					mu.Lock()
+					errs[tk.url] = err
+					pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
+					mu.Unlock()
+
+					if errorCallback != nil {
+						errorCallback(tk.url, err, len(tasks))
+					}
+					wgTasks.Done()
+					continue
+				}
+
+				func() {
+					defer resp.Body.Close()
+					ct := resp.Header.Get("Content-Type")
+					// Default meta values
+					meta := PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
+					if resp.ContentLength > 0 {
+						meta.ContentLength = int(resp.ContentLength)
+					}
+					if resp.StatusCode != http.StatusOK || ct == "" || (ct != "text/html" && !hasPrefix(ct, "text/html")) {
+						mu.Lock()
+						pageInfos[tk.url] = meta
+						mu.Unlock()
+						return
+					}
+
+					// Limit body read to prevent memory issues
+					body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) // 1MB limit
+					meta.ContentLength = len(body)
+					meta.Title = htmlx.ExtractTitle(stringsReader(string(body)))
+					hrefs := htmlx.ExtractAnchors(stringsReader(string(body)))
+
+					var toEnqueue []string
+					for _, href := range hrefs {
+						abs, ok := urlutil.Normalize(tk.url, href)
+						if !ok {
+							continue
+						}
+						mu.Lock()
+						m, ok2 := outlinks[tk.url]
+						if !ok2 {
+							m = make(map[string]struct{})
+							outlinks[tk.url] = m
+						}
+						m[abs] = struct{}{}
+						mu.Unlock()
+
+						if tk.depth < maxDepth {
+							if !sameHostOnly || urlutil.SameHost(origin, abs) {
+								toEnqueue = append(toEnqueue, abs)
+							}
+						}
+					}
+
+					// Limit the number of links to follow from a single page
+					if len(toEnqueue) > 50 {
+						toEnqueue = toEnqueue[:50]
+					}
+
+					for _, u := range toEnqueue {
+						enqueue(task{url: u, depth: tk.depth + 1})
+					}
+					mu.Lock()
+					pageInfos[tk.url] = meta
+					mu.Unlock()
+				}()
+				wgTasks.Done()
+			}
+		}
+	}
+
+	for i := 0; i < concurrency; i++ {
+		wgWorkers.Add(1)
+		go worker()
+	}
+
+	// Close the tasks channel when all enqueued tasks are processed or timeout occurs
+	go func() {
+		done := make(chan struct{})
+		go func() {
+			wgTasks.Wait()
+			close(done)
+		}()
+
+		select {
+		case <-done:
+			// All tasks completed
+		case <-ctx.Done():
+			// Timeout occurred
+		}
+		close(tasks)
+	}()
+
+	enqueue(task{url: startURL, depth: 0})
+
+	// Wait for workers to finish or timeout
+	workersDone := make(chan struct{})
+	go func() {
+		wgWorkers.Wait()
+		close(workersDone)
+	}()
+
+	select {
+	case <-workersDone:
+		// Workers finished normally
+	case <-ctx.Done():
+		// Timeout occurred
+	}
+
+	return visited, errs, outlinks, pageInfos
+}
+
+func hasPrefix(s string, prefix string) bool {
+	return len(s) >= len(prefix) && s[:len(prefix)] == prefix
+}
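The main.go hunks below show the real call site; as a minimal standalone sketch of the new API, a caller wraps its base context with the global deadline and passes the URL cap explicitly. The import path, client timeout, and user-agent string here are illustrative assumptions; only the crawler package name and the CrawlWithSafety signature come from this commit.

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"

	"example.local/sitecrawler/crawler" // import path assumed for illustration
)

func main() {
	// Global safety net: the crawl as a whole is bounded by this deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client := &http.Client{Timeout: 15 * time.Second} // transport-level timeout; value assumed

	visited, errs, _, pages := crawler.CrawlWithSafety(
		ctx, "https://example.com", 3, 10, // startURL, maxDepth, concurrency
		true, client, "example-crawler/1.0", // sameHostOnly, client, userAgent (UA assumed)
		nil, nil, // progress callbacks are optional; the crawler nil-checks both
		500, // maxURLs: hard cap on how many URLs may ever be enqueued
	)
	fmt.Printf("visited=%d errors=%d pages=%d\n", len(visited), len(errs), len(pages))
}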
main.go (41 changed lines)
@@ -32,6 +32,8 @@ func main() {
 	var output string
 	var quiet bool
 	var exportDir string
+	var maxURLs int
+	var globalTimeout int
 
 	flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)")
 	flag.IntVar(&concurrency, "concurrency", 10, "Number of concurrent workers")
@@ -42,6 +44,8 @@
 	flag.StringVar(&output, "output", "text", "Output format: text|json")
 	flag.BoolVar(&quiet, "quiet", false, "Suppress progress output")
 	flag.StringVar(&exportDir, "export-dir", "exports", "Directory to write CSV/NDJSON exports into (set empty to disable)")
+	flag.IntVar(&maxURLs, "max-urls", 500, "Maximum number of URLs to crawl")
+	flag.IntVar(&globalTimeout, "global-timeout", 120, "Global timeout in seconds for the entire crawl")
 	flag.Parse()
 
 	if strings.TrimSpace(target) == "" {
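With these flags registered, an invocation that exercises the new limits could look like the following (run from the module root; the diff's defaults are 500 URLs and 120 seconds):

go run . -target https://example.com -max-urls 200 -global-timeout 60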
@@ -105,7 +109,11 @@
 		currentURL.Store(u)
 	}
 
-	visited, crawlErrs, outlinks, pageInfo := crawler.Crawl(ctx, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback)
+	// Create a context with timeout for the entire crawl
+	ctxWithGlobalTimeout, cancelGlobal := context.WithTimeout(ctx, time.Duration(globalTimeout)*time.Second)
+	defer cancelGlobal()
+
+	visited, crawlErrs, outlinks, pageInfo := crawler.CrawlWithSafety(ctxWithGlobalTimeout, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback, maxURLs)
 
 	// Clear progress line before moving to next phase
 	if !quiet {
@@ -195,8 +203,8 @@
 		}
 	}
 
-	// Save JSON report to ./reports/<host>.json by default (ignored by git)
-	if err := saveReportJSON("reports", reports); err != nil {
+	// Save JSON report to ./reports/<host>/report.json by default (ignored by git)
+	if err := saveReportToSiteDir(reports); err != nil {
 		fmt.Fprintf(os.Stderr, "save report error: %v\n", err)
 	}
 
@@ -341,21 +349,40 @@ func linkStatusesToNDJSON(r report.Report) []ndjsonItem {
 	return res
 }
 
-func saveReportJSON(baseDir string, r report.Report) error {
+// saveReportToSiteDir saves the report to a subdirectory named after the site's hostname
+// under the "reports" directory.
+func saveReportToSiteDir(r report.Report) error {
 	u, err := url.Parse(r.Target)
 	if err != nil || u.Host == "" {
 		return fmt.Errorf("invalid target for save: %s", r.Target)
 	}
-	if err := os.MkdirAll(baseDir, 0o755); err != nil {
+
+	// Create base reports directory
+	reportsDir := "reports"
+	if err := os.MkdirAll(reportsDir, 0o755); err != nil {
 		return err
 	}
-	path := filepath.Join(baseDir, u.Host+".json")
+
+	// Create subdirectory for this site
+	siteDir := filepath.Join(reportsDir, u.Host)
+	if err := os.MkdirAll(siteDir, 0o755); err != nil {
+		return err
+	}
+
+	// Save report to site subdirectory
+	path := filepath.Join(siteDir, "report.json")
 	f, err := os.Create(path)
 	if err != nil {
 		return err
 	}
 	defer f.Close()
 
 	enc := json.NewEncoder(f)
 	enc.SetIndent("", "  ")
-	return enc.Encode(r)
+	if err := enc.Encode(r); err != nil {
+		return err
+	}
+
+	fmt.Fprintf(os.Stderr, "Report saved to %s\n", path)
+	return nil
 }
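For a target of https://example.com, the report now lands under a per-site directory instead of a flat <host>.json file, so reports for different sites no longer share one directory:

reports/
	example.com/
		report.json

The "Report saved to …" line on stderr echoes the final path after a successful write.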