Add safer crawler with timeouts and site-specific report directories

Colin 2025-09-22 17:25:47 -04:00
parent bbb7808d1f
commit f2ae09dd78
Signed by: colin
SSH Key Fingerprint: SHA256:nRPCQTeMFLdGytxRQmPVK9VXY3/ePKQ5lGRyJhT5DY8
3 changed files with 501 additions and 7670 deletions


@@ -0,0 +1,240 @@
package crawlersafe

import (
    "context"
    "io"
    "net/http"
    "sync"
    "time"

    "urlcrawler/internal/htmlx"
    "urlcrawler/internal/urlutil"
)

type task struct {
    url   string
    depth int
}

type PageInfo struct {
    Title          string
    ResponseTimeMs int64
    ContentLength  int
    Depth          int
}
// CrawlWithSafety is a safer version of the crawler that adds a global timeout
// and limits the maximum number of URLs to process to prevent infinite loops.
func CrawlWithSafety(ctx context.Context, startURL string, maxDepth int, concurrency int,
    sameHostOnly bool, client *http.Client, userAgent string,
    visitedCallback func(string, int, int), errorCallback func(string, error, int),
    maxURLs int, timeoutSeconds int) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) {
    // Create a context with timeout
    ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
    defer cancel()

    visited := make(map[string]struct{})
    errs := make(map[string]error)
    outlinks := make(map[string]map[string]struct{})
    pageInfos := make(map[string]PageInfo)
    var mu sync.Mutex
    var urlCounter int

    origin := urlutil.Origin(startURL)
    tasks := make(chan task, concurrency*2)
    wgWorkers := sync.WaitGroup{}
    wgTasks := sync.WaitGroup{}

    enqueue := func(t task) {
        mu.Lock()
        defer mu.Unlock()
        // Check if we've reached the URL limit
        if urlCounter >= maxURLs {
            return
        }
        urlCounter++
        wgTasks.Add(1)
        select {
        case tasks <- t:
            // Successfully enqueued
        case <-ctxWithTimeout.Done():
            // Context canceled or timed out
            wgTasks.Done()
        }
    }

    worker := func() {
        defer wgWorkers.Done()
        for {
            select {
            case <-ctxWithTimeout.Done():
                return
            case tk, ok := <-tasks:
                if !ok {
                    return
                }
                if ctxWithTimeout.Err() != nil {
                    wgTasks.Done()
                    return
                }
                mu.Lock()
                if _, seen := visited[tk.url]; seen {
                    mu.Unlock()
                    wgTasks.Done()
                    continue
                }
                visited[tk.url] = struct{}{}
                mu.Unlock()
                if visitedCallback != nil {
                    visitedCallback(tk.url, tk.depth, len(tasks))
                }
                // Create a context with timeout for this specific request
                reqCtx, reqCancel := context.WithTimeout(ctxWithTimeout, 10*time.Second)
                start := time.Now()
                // A malformed URL is recorded as a crawl error instead of panicking on a nil request.
                req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, tk.url, nil)
                var resp *http.Response
                if err == nil {
                    req.Header.Set("User-Agent", userAgent)
                    resp, err = client.Do(req)
                }
                reqCancel() // Cancel the request context
                if err != nil {
                    mu.Lock()
                    errs[tk.url] = err
                    pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
                    mu.Unlock()
                    if errorCallback != nil {
                        errorCallback(tk.url, err, len(tasks))
                    }
                    wgTasks.Done()
                    continue
                }
                func() {
                    defer resp.Body.Close()
                    ct := resp.Header.Get("Content-Type")
                    // Default meta values
                    meta := PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
                    if resp.ContentLength > 0 {
                        meta.ContentLength = int(resp.ContentLength)
                    }
                    if resp.StatusCode != http.StatusOK || !hasPrefix(ct, "text/html") {
                        mu.Lock()
                        pageInfos[tk.url] = meta
                        mu.Unlock()
                        return
                    }
                    // Limit body read to prevent memory issues
                    body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) // 1MB limit
                    meta.ContentLength = len(body)
                    page := string(body)
                    meta.Title = htmlx.ExtractTitle(stringsReader(page))
                    hrefs := htmlx.ExtractAnchors(stringsReader(page))
                    var toEnqueue []string
                    for _, href := range hrefs {
                        abs, ok := urlutil.Normalize(tk.url, href)
                        if !ok {
                            continue
                        }
                        mu.Lock()
                        m, ok2 := outlinks[tk.url]
                        if !ok2 {
                            m = make(map[string]struct{})
                            outlinks[tk.url] = m
                        }
                        m[abs] = struct{}{}
                        mu.Unlock()
                        if tk.depth < maxDepth {
                            if !sameHostOnly || urlutil.SameHost(origin, abs) {
                                toEnqueue = append(toEnqueue, abs)
                            }
                        }
                    }
                    // Limit the number of links to follow from a single page
                    if len(toEnqueue) > 50 {
                        toEnqueue = toEnqueue[:50]
                    }
                    for _, u := range toEnqueue {
                        enqueue(task{url: u, depth: tk.depth + 1})
                    }
                    mu.Lock()
                    pageInfos[tk.url] = meta
                    mu.Unlock()
                }()
                wgTasks.Done()
            }
        }
    }
    for i := 0; i < concurrency; i++ {
        wgWorkers.Add(1)
        go worker()
    }

    // Close the tasks channel when all enqueued tasks are processed or timeout occurs
    go func() {
        done := make(chan struct{})
        go func() {
            wgTasks.Wait()
            close(done)
        }()
        select {
        case <-done:
            // All tasks completed
        case <-ctxWithTimeout.Done():
            // Timeout occurred
        }
        close(tasks)
    }()

    enqueue(task{url: startURL, depth: 0})

    // Wait for workers to finish or timeout
    workersDone := make(chan struct{})
    go func() {
        wgWorkers.Wait()
        close(workersDone)
    }()
    select {
    case <-workersDone:
        // Workers finished normally
    case <-ctxWithTimeout.Done():
        // Timeout occurred
    }
    return visited, errs, outlinks, pageInfos
}
func hasPrefix(s string, prefix string) bool {
    return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}

// stringsReader avoids importing strings at package top for a single use.
func stringsReader(s string) io.Reader {
    return &stringReader{str: s}
}

type stringReader struct{ str string }

func (r *stringReader) Read(p []byte) (int, error) {
    if len(r.str) == 0 {
        return 0, io.EOF
    }
    n := copy(p, r.str)
    r.str = r.str[n:]
    return n, nil
}
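
As a point of reference, a minimal smoke-test sketch for the new package against a local httptest server. This is not part of the commit; the test file, the fixture HTML, and the expectation that urlutil.Normalize resolves the relative /about link are assumptions for illustration only.

package crawlersafe_test

import (
    "context"
    "fmt"
    "net/http"
    "net/http/httptest"
    "testing"
    "time"

    "urlcrawler/internal/crawlersafe"
)

// Crawl a two-page site and confirm the start URL is recorded without errors.
func TestCrawlWithSafety_Smoke(t *testing.T) {
    mux := http.NewServeMux()
    mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("Content-Type", "text/html")
        fmt.Fprint(w, `<html><head><title>Home</title></head><body><a href="/about">About</a></body></html>`)
    })
    mux.HandleFunc("/about", func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("Content-Type", "text/html")
        fmt.Fprint(w, `<html><head><title>About</title></head><body>ok</body></html>`)
    })
    srv := httptest.NewServer(mux)
    defer srv.Close()

    client := &http.Client{Timeout: 5 * time.Second}
    // maxDepth=1, concurrency=2, sameHostOnly=true, nil callbacks, maxURLs=10, timeoutSeconds=15.
    visited, errs, outlinks, pages := crawlersafe.CrawlWithSafety(
        context.Background(), srv.URL, 1, 2, true,
        client, "urlcrawler-test", nil, nil, 10, 15)

    if len(errs) != 0 {
        t.Fatalf("unexpected crawl errors: %v", errs)
    }
    if _, ok := visited[srv.URL]; !ok {
        t.Fatalf("start URL not visited: %v", visited)
    }
    t.Logf("visited=%d outlinks=%d pages=%d", len(visited), len(outlinks), len(pages))
}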

File diff suppressed because it is too large

safe_crawler.go (new file, 261 lines)

@@ -0,0 +1,261 @@
package main

import (
    "context"
    "encoding/json"
    "errors"
    "flag"
    "fmt"
    "net/http"
    "net/url"
    "os"
    "path/filepath"
    "strings"
    "sync/atomic"
    "time"

    "urlcrawler/internal/crawlersafe"
    "urlcrawler/internal/linkcheck"
    "urlcrawler/internal/report"
    "urlcrawler/internal/sitemap"
    "urlcrawler/internal/urlutil"
)
func main() {
    var target string
    var concurrency int
    var timeout time.Duration
    var maxDepth int
    var userAgent string
    var sameHostOnly bool
    var output string
    var quiet bool
    var maxURLs int
    var globalTimeout int

    flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)")
    flag.IntVar(&concurrency, "concurrency", 5, "Number of concurrent workers")
    flag.DurationVar(&timeout, "timeout", 10*time.Second, "HTTP timeout per request")
    flag.IntVar(&maxDepth, "max-depth", 2, "Maximum crawl depth (0=crawl only the start page)")
    flag.StringVar(&userAgent, "user-agent", "urlcrawler/1.0", "User-Agent header value")
    flag.BoolVar(&sameHostOnly, "same-host-only", true, "Limit crawl to the same host as target")
    flag.StringVar(&output, "output", "text", "Output format: text|json")
    flag.BoolVar(&quiet, "quiet", false, "Suppress progress output")
    flag.IntVar(&maxURLs, "max-urls", 100, "Maximum number of URLs to crawl")
    flag.IntVar(&globalTimeout, "global-timeout", 60, "Global timeout in seconds for the entire crawl")
    flag.Parse()

    if strings.TrimSpace(target) == "" {
        fmt.Fprintln(os.Stderr, "-target is required")
        flag.Usage()
        os.Exit(2)
    }

    client := &http.Client{Timeout: timeout}
    ctx := context.Background()

    // Report metadata
    started := time.Now()
    meta := report.Metadata{StartedAt: started.UTC().Format(time.RFC3339)}
    params := report.Params{
        MaxDepth:     maxDepth,
        Concurrency:  concurrency,
        TimeoutMs:    timeout.Milliseconds(),
        UserAgent:    userAgent,
        SameHostOnly: sameHostOnly,
    }
fmt.Fprintf(os.Stderr, "Starting crawl of %s (depth: %d, max URLs: %d, timeout: %ds)...\n",
target, maxDepth, maxURLs, globalTimeout)
// Setup progress counters
var urlsVisited, urlsErrored atomic.Int64
var currentURL atomic.Value // string
var pendingTasks atomic.Int64
// Start progress reporter if not in quiet mode
ctxWithCancel, cancel := context.WithCancel(ctx)
defer cancel()
if !quiet {
go func() {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cu, _ := currentURL.Load().(string)
fmt.Fprintf(os.Stderr, "\rURLs visited: %d | Errors: %d | Pending: %d | Current: %s",
urlsVisited.Load(), urlsErrored.Load(), pendingTasks.Load(), truncateForTTY(cu, 90))
case <-ctxWithCancel.Done():
return
}
}
}()
}
// Progress callback functions
visitedCallback := func(u string, depth int, pending int) {
urlsVisited.Add(1)
pendingTasks.Store(int64(pending))
currentURL.Store(u)
}
errorCallback := func(u string, err error, pending int) {
urlsErrored.Add(1)
pendingTasks.Store(int64(pending))
currentURL.Store(u)
}
// Use the safer crawler with timeout and URL limits
visited, crawlErrs, outlinks, pageInfo := crawlersafe.CrawlWithSafety(
ctx, target, maxDepth, concurrency, sameHostOnly,
client, userAgent, visitedCallback, errorCallback,
maxURLs, globalTimeout)
// Clear progress line before moving to next phase
if !quiet {
fmt.Fprintf(os.Stderr, "\rCrawl complete! URLs visited: %d | Errors: %d\n",
urlsVisited.Load(), urlsErrored.Load())
}
fmt.Fprintf(os.Stderr, "Fetching sitemap...\n")
smURLs, err := sitemap.FetchAll(ctx, target, client, userAgent)
if err != nil && !errors.Is(err, sitemap.ErrNotFound) {
fmt.Fprintf(os.Stderr, "sitemap error: %v\n", err)
}
    // Robots.txt summary (simple)
    robots := report.RobotsSummary{}
    robotsURL := urlutil.Origin(target) + "/robots.txt"
    {
        req, err := http.NewRequestWithContext(ctx, http.MethodGet, robotsURL, nil)
        if err == nil {
            req.Header.Set("User-Agent", userAgent)
            resp, derr := client.Do(req)
            if derr == nil {
                if resp.StatusCode == http.StatusOK {
                    robots.Present = true
                    robots.FetchedAt = time.Now().UTC().Format(time.RFC3339)
                }
                // Close now; a defer inside this block would not run until main returns.
                resp.Body.Close()
            }
        }
    }
    // Build set of all unique links discovered across pages for status checks
    allLinks := make(map[string]struct{})
    for _, m := range outlinks {
        for u := range m {
            allLinks[u] = struct{}{}
        }
    }
    // Also include the visited pages themselves
    for u := range visited {
        allLinks[u] = struct{}{}
    }

    fmt.Fprintf(os.Stderr, "Checking %d links...\n", len(allLinks))

    // Reset counters for link checking
    urlsVisited.Store(0)
    urlsErrored.Store(0)

    // Progress callback for link checking
    linkCheckCallback := func(ok bool) {
        if ok {
            urlsVisited.Add(1)
        } else {
            urlsErrored.Add(1)
        }
    }
    checkResults := linkcheck.Check(ctx, allLinks, concurrency, client, userAgent, !quiet, linkCheckCallback)

    // Clear progress line before finishing
    if !quiet {
        fmt.Fprintf(os.Stderr, "\rLink checking complete! OK: %d | Errors: %d\n",
            urlsVisited.Load(), urlsErrored.Load())
    }

    finished := time.Now()
    meta.FinishedAt = finished.UTC().Format(time.RFC3339)
    meta.DurationMs = finished.Sub(started).Milliseconds()

    fmt.Fprintf(os.Stderr, "Building report...\n")
    // Convert pageInfo to report.PageMeta
    pages := make(map[string]report.PageMeta, len(pageInfo))
    for u, pi := range pageInfo {
        pages[u] = report.PageMeta{
            Title:          pi.Title,
            ResponseTimeMs: pi.ResponseTimeMs,
            ContentLength:  pi.ContentLength,
            Depth:          pi.Depth,
        }
    }
    reports := report.Build(target, visited, smURLs, crawlErrs, checkResults, outlinks, meta, params, pages, robots)

    // Save report to a subdirectory named after the site
    if err := saveReportToSiteDir(reports); err != nil {
        fmt.Fprintf(os.Stderr, "save report error: %v\n", err)
    }

    // Output the report
    switch output {
    case "json":
        enc := json.NewEncoder(os.Stdout)
        enc.SetIndent("", " ")
        _ = enc.Encode(reports)
    default:
        report.PrintText(os.Stdout, reports)
    }
}
// truncateForTTY truncates s to max characters, replacing the tail with … if needed.
func truncateForTTY(s string, max int) string {
    if max <= 0 || len(s) <= max {
        return s
    }
    if max <= 1 {
        return "…"
    }
    return s[:max-1] + "…"
}
// saveReportToSiteDir saves the report to a subdirectory named after the site's hostname
// under the "reports" directory.
func saveReportToSiteDir(r report.Report) error {
    u, err := url.Parse(r.Target)
    if err != nil || u.Host == "" {
        return fmt.Errorf("invalid target for save: %s", r.Target)
    }
    // Create base reports directory
    reportsDir := "reports"
    if err := os.MkdirAll(reportsDir, 0o755); err != nil {
        return err
    }
    // Create subdirectory for this site
    siteDir := filepath.Join(reportsDir, u.Host)
    if err := os.MkdirAll(siteDir, 0o755); err != nil {
        return err
    }
    // Save report to site subdirectory
    path := filepath.Join(siteDir, "report.json")
    f, err := os.Create(path)
    if err != nil {
        return err
    }
    defer f.Close()
    enc := json.NewEncoder(f)
    enc.SetIndent("", " ")
    if err := enc.Encode(r); err != nil {
        return err
    }
    fmt.Fprintf(os.Stderr, "Report saved to %s\n", path)
    return nil
}
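
For context, a hedged usage sketch of the new entry point. The run command assumes safe_crawler.go builds as a standalone main at the module root, which may not match the actual repository layout; example.com is a placeholder host. The flag names and the reports/<host>/report.json path come from the code above.

    go run ./safe_crawler.go -target https://example.com -concurrency 5 -max-depth 2 \
        -max-urls 100 -global-timeout 60 -output json > example.com.json

Progress and status messages go to stderr, the JSON report goes to stdout, and saveReportToSiteDir additionally writes reports/example.com/report.json, so crawls of different hosts land in separate per-site directories.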