Add safer crawler with timeouts and site-specific report directories
parent bbb7808d1f
commit f2ae09dd78
@@ -0,0 +1,240 @@
package crawlersafe

import (
	"context"
	"io"
	"net/http"
	"sync"
	"time"

	"urlcrawler/internal/htmlx"
	"urlcrawler/internal/urlutil"
)

type task struct {
	url   string
	depth int
}

type PageInfo struct {
	Title          string
	ResponseTimeMs int64
	ContentLength  int
	Depth          int
}

// CrawlWithSafety is a safer version of the crawler that adds a global timeout
// and limits the maximum number of URLs to process to prevent infinite loops.
func CrawlWithSafety(ctx context.Context, startURL string, maxDepth int, concurrency int,
	sameHostOnly bool, client *http.Client, userAgent string,
	visitedCallback func(string, int, int), errorCallback func(string, error, int),
	maxURLs int, timeoutSeconds int) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) {

	// Create a context with timeout
	ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
	defer cancel()

	visited := make(map[string]struct{})
	errs := make(map[string]error)
	outlinks := make(map[string]map[string]struct{})
	pageInfos := make(map[string]PageInfo)
	var mu sync.Mutex
	var urlCounter int

	origin := urlutil.Origin(startURL)

	tasks := make(chan task, concurrency*2)
	wgWorkers := sync.WaitGroup{}
	wgTasks := sync.WaitGroup{}

	enqueue := func(t task) {
		mu.Lock()
		defer mu.Unlock()

		// Check if we've reached the URL limit
		if urlCounter >= maxURLs {
			return
		}

		urlCounter++
		wgTasks.Add(1)
		select {
		case tasks <- t:
			// Successfully enqueued
		case <-ctxWithTimeout.Done():
			// Context canceled or timed out
			wgTasks.Done()
		}
	}

	worker := func() {
		defer wgWorkers.Done()
		for {
			select {
			case <-ctxWithTimeout.Done():
				return
			case tk, ok := <-tasks:
				if !ok {
					return
				}

				if ctxWithTimeout.Err() != nil {
					wgTasks.Done()
					return
				}

				mu.Lock()
				if _, seen := visited[tk.url]; seen {
					mu.Unlock()
					wgTasks.Done()
					continue
				}
				visited[tk.url] = struct{}{}
				mu.Unlock()

				if visitedCallback != nil {
					visitedCallback(tk.url, tk.depth, len(tasks))
				}

				// Create a context with timeout for this specific request
				reqCtx, reqCancel := context.WithTimeout(ctxWithTimeout, 10*time.Second)

				start := time.Now()
				req, _ := http.NewRequestWithContext(reqCtx, http.MethodGet, tk.url, nil)
				req.Header.Set("User-Agent", userAgent)
				resp, err := client.Do(req)
				reqCancel() // Cancel the request context

				if err != nil {
					mu.Lock()
					errs[tk.url] = err
					pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
					mu.Unlock()

					if errorCallback != nil {
						errorCallback(tk.url, err, len(tasks))
					}
					wgTasks.Done()
					continue
				}

				func() {
					defer resp.Body.Close()
					ct := resp.Header.Get("Content-Type")
					// Default meta values
					meta := PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
					if resp.ContentLength > 0 {
						meta.ContentLength = int(resp.ContentLength)
					}
					if resp.StatusCode != http.StatusOK || ct == "" || (ct != "text/html" && !hasPrefix(ct, "text/html")) {
						mu.Lock()
						pageInfos[tk.url] = meta
						mu.Unlock()
						return
					}

					// Limit body read to prevent memory issues
					body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) // 1MB limit
					meta.ContentLength = len(body)
					meta.Title = htmlx.ExtractTitle(stringsReader(string(body)))
					hrefs := htmlx.ExtractAnchors(stringsReader(string(body)))

					var toEnqueue []string
					for _, href := range hrefs {
						abs, ok := urlutil.Normalize(tk.url, href)
						if !ok {
							continue
						}
						mu.Lock()
						m, ok2 := outlinks[tk.url]
						if !ok2 {
							m = make(map[string]struct{})
							outlinks[tk.url] = m
						}
						m[abs] = struct{}{}
						mu.Unlock()

						if tk.depth < maxDepth {
							if !sameHostOnly || urlutil.SameHost(origin, abs) {
								toEnqueue = append(toEnqueue, abs)
							}
						}
					}

					// Limit the number of links to follow from a single page
					if len(toEnqueue) > 50 {
						toEnqueue = toEnqueue[:50]
					}

					for _, u := range toEnqueue {
						enqueue(task{url: u, depth: tk.depth + 1})
					}
					mu.Lock()
					pageInfos[tk.url] = meta
					mu.Unlock()
				}()
				wgTasks.Done()
			}
		}
	}

	for i := 0; i < concurrency; i++ {
		wgWorkers.Add(1)
		go worker()
	}

	// Seed the queue before starting the closer goroutine so the task
	// WaitGroup is non-zero when Wait is first called; otherwise the tasks
	// channel could be closed before the start URL is enqueued.
	enqueue(task{url: startURL, depth: 0})

	// Close the tasks channel when all enqueued tasks are processed or timeout occurs
	go func() {
		done := make(chan struct{})
		go func() {
			wgTasks.Wait()
			close(done)
		}()

		select {
		case <-done:
			// All tasks completed
		case <-ctxWithTimeout.Done():
			// Timeout occurred
		}
		close(tasks)
	}()

	// Wait for workers to finish or timeout
	workersDone := make(chan struct{})
	go func() {
		wgWorkers.Wait()
		close(workersDone)
	}()

	select {
	case <-workersDone:
		// Workers finished normally
	case <-ctxWithTimeout.Done():
		// Timeout occurred
	}

	return visited, errs, outlinks, pageInfos
}

func hasPrefix(s string, prefix string) bool {
	return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}

// stringsReader avoids importing strings at package top for a single use.
func stringsReader(s string) io.Reader {
	return &stringReader{str: s}
}

type stringReader struct{ str string }

func (r *stringReader) Read(p []byte) (int, error) {
	if len(r.str) == 0 {
		return 0, io.EOF
	}
	n := copy(p, r.str)
	r.str = r.str[n:]
	return n, nil
}
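
As an illustration of the new API surface (not part of the diff above), here is a minimal sketch of driving CrawlWithSafety directly; the target URL, limits, and nil callbacks are placeholder values, and the full wiring used by the tool is in the main package below.

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"

	"urlcrawler/internal/crawlersafe"
)

func main() {
	client := &http.Client{Timeout: 10 * time.Second}
	// Illustrative limits: depth 1, 4 workers, at most 25 URLs, 30s overall budget.
	// Passing nil callbacks is allowed; the crawler checks for nil before calling them.
	visited, errs, outlinks, pages := crawlersafe.CrawlWithSafety(
		context.Background(), "https://example.com", 1, 4,
		true, client, "urlcrawler/1.0", nil, nil,
		25, 30)
	fmt.Printf("visited=%d errors=%d pages-with-outlinks=%d page-metas=%d\n",
		len(visited), len(errs), len(outlinks), len(pages))
}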

report.json (7670): file diff suppressed because it is too large.

@@ -0,0 +1,261 @@
package main

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync/atomic"
	"time"

	"urlcrawler/internal/crawlersafe"
	"urlcrawler/internal/linkcheck"
	"urlcrawler/internal/report"
	"urlcrawler/internal/sitemap"
	"urlcrawler/internal/urlutil"
)

func main() {
	var target string
	var concurrency int
	var timeout time.Duration
	var maxDepth int
	var userAgent string
	var sameHostOnly bool
	var output string
	var quiet bool
	var maxURLs int
	var globalTimeout int

	flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)")
	flag.IntVar(&concurrency, "concurrency", 5, "Number of concurrent workers")
	flag.DurationVar(&timeout, "timeout", 10*time.Second, "HTTP timeout per request")
	flag.IntVar(&maxDepth, "max-depth", 2, "Maximum crawl depth (0=crawl only the start page)")
	flag.StringVar(&userAgent, "user-agent", "urlcrawler/1.0", "User-Agent header value")
	flag.BoolVar(&sameHostOnly, "same-host-only", true, "Limit crawl to the same host as target")
	flag.StringVar(&output, "output", "text", "Output format: text|json")
	flag.BoolVar(&quiet, "quiet", false, "Suppress progress output")
	flag.IntVar(&maxURLs, "max-urls", 100, "Maximum number of URLs to crawl")
	flag.IntVar(&globalTimeout, "global-timeout", 60, "Global timeout in seconds for the entire crawl")
	flag.Parse()

	if strings.TrimSpace(target) == "" {
		fmt.Fprintln(os.Stderr, "-target is required")
		flag.Usage()
		os.Exit(2)
	}

	client := &http.Client{Timeout: timeout}
	ctx := context.Background()

	// Report metadata
	started := time.Now()
	meta := report.Metadata{StartedAt: started.UTC().Format(time.RFC3339)}
	params := report.Params{
		MaxDepth:     maxDepth,
		Concurrency:  concurrency,
		TimeoutMs:    timeout.Milliseconds(),
		UserAgent:    userAgent,
		SameHostOnly: sameHostOnly,
	}

	fmt.Fprintf(os.Stderr, "Starting crawl of %s (depth: %d, max URLs: %d, timeout: %ds)...\n",
		target, maxDepth, maxURLs, globalTimeout)

	// Setup progress counters
	var urlsVisited, urlsErrored atomic.Int64
	var currentURL atomic.Value // string
	var pendingTasks atomic.Int64

	// Start progress reporter if not in quiet mode
	ctxWithCancel, cancel := context.WithCancel(ctx)
	defer cancel()

	if !quiet {
		go func() {
			ticker := time.NewTicker(500 * time.Millisecond)
			defer ticker.Stop()

			for {
				select {
				case <-ticker.C:
					cu, _ := currentURL.Load().(string)
					fmt.Fprintf(os.Stderr, "\rURLs visited: %d | Errors: %d | Pending: %d | Current: %s",
						urlsVisited.Load(), urlsErrored.Load(), pendingTasks.Load(), truncateForTTY(cu, 90))
				case <-ctxWithCancel.Done():
					return
				}
			}
		}()
	}

	// Progress callback functions
	visitedCallback := func(u string, depth int, pending int) {
		urlsVisited.Add(1)
		pendingTasks.Store(int64(pending))
		currentURL.Store(u)
	}
	errorCallback := func(u string, err error, pending int) {
		urlsErrored.Add(1)
		pendingTasks.Store(int64(pending))
		currentURL.Store(u)
	}

	// Use the safer crawler with timeout and URL limits
	visited, crawlErrs, outlinks, pageInfo := crawlersafe.CrawlWithSafety(
		ctx, target, maxDepth, concurrency, sameHostOnly,
		client, userAgent, visitedCallback, errorCallback,
		maxURLs, globalTimeout)

	// Clear progress line before moving to next phase
	if !quiet {
		fmt.Fprintf(os.Stderr, "\rCrawl complete! URLs visited: %d | Errors: %d\n",
			urlsVisited.Load(), urlsErrored.Load())
	}

	fmt.Fprintf(os.Stderr, "Fetching sitemap...\n")
	smURLs, err := sitemap.FetchAll(ctx, target, client, userAgent)
	if err != nil && !errors.Is(err, sitemap.ErrNotFound) {
		fmt.Fprintf(os.Stderr, "sitemap error: %v\n", err)
	}

	// Robots.txt summary (simple)
	robots := report.RobotsSummary{}
	robotsURL := urlutil.Origin(target) + "/robots.txt"
	{
		req, _ := http.NewRequestWithContext(ctx, http.MethodGet, robotsURL, nil)
		req.Header.Set("User-Agent", userAgent)
		resp, err := client.Do(req)
		if err == nil {
			defer resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				robots.Present = true
				robots.FetchedAt = time.Now().UTC().Format(time.RFC3339)
			}
		}
	}

	// Build set of all unique links discovered across pages for status checks
	allLinks := make(map[string]struct{})
	for _, m := range outlinks {
		for u := range m {
			allLinks[u] = struct{}{}
		}
	}
	// Also include the visited pages themselves
	for u := range visited {
		allLinks[u] = struct{}{}
	}

	fmt.Fprintf(os.Stderr, "Checking %d links...\n", len(allLinks))

	// Reset counters for link checking
	urlsVisited.Store(0)
	urlsErrored.Store(0)

	// Progress callback function for link checking
	linkCheckCallback := func(ok bool) {
		if ok {
			urlsVisited.Add(1)
		} else {
			urlsErrored.Add(1)
		}
	}

	checkResults := linkcheck.Check(ctx, allLinks, concurrency, client, userAgent, !quiet, linkCheckCallback)

	// Clear progress line before finishing
	if !quiet {
		fmt.Fprintf(os.Stderr, "\rLink checking complete! OK: %d | Errors: %d\n",
			urlsVisited.Load(), urlsErrored.Load())
	}

	finished := time.Now()
	meta.FinishedAt = finished.UTC().Format(time.RFC3339)
	meta.DurationMs = finished.Sub(started).Milliseconds()

	fmt.Fprintf(os.Stderr, "Building report...\n")

	// Convert pageInfo to report.PageMeta
	pages := make(map[string]report.PageMeta, len(pageInfo))
	for u, pi := range pageInfo {
		pages[u] = report.PageMeta{
			Title:          pi.Title,
			ResponseTimeMs: pi.ResponseTimeMs,
			ContentLength:  pi.ContentLength,
			Depth:          pi.Depth,
		}
	}

	reports := report.Build(target, visited, smURLs, crawlErrs, checkResults, outlinks, meta, params, pages, robots)

	// Save report to a subdirectory named after the site
	if err := saveReportToSiteDir(reports); err != nil {
		fmt.Fprintf(os.Stderr, "save report error: %v\n", err)
	}

	// Output the report
	switch output {
	case "json":
		enc := json.NewEncoder(os.Stdout)
		enc.SetIndent("", " ")
		_ = enc.Encode(reports)
	default:
		report.PrintText(os.Stdout, reports)
	}
}

// truncateForTTY truncates s to max characters, replacing the tail with … if needed.
func truncateForTTY(s string, max int) string {
	if max <= 0 || len(s) <= max {
		return s
	}
	if max <= 1 {
		return "…"
	}
	return s[:max-1] + "…"
}

// saveReportToSiteDir saves the report to a subdirectory named after the site's hostname
// under the "reports" directory.
func saveReportToSiteDir(r report.Report) error {
	u, err := url.Parse(r.Target)
	if err != nil || u.Host == "" {
		return fmt.Errorf("invalid target for save: %s", r.Target)
	}

	// Create base reports directory
	reportsDir := "reports"
	if err := os.MkdirAll(reportsDir, 0o755); err != nil {
		return err
	}

	// Create subdirectory for this site
	siteDir := filepath.Join(reportsDir, u.Host)
	if err := os.MkdirAll(siteDir, 0o755); err != nil {
		return err
	}

	// Save report to site subdirectory
	path := filepath.Join(siteDir, "report.json")
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()

	enc := json.NewEncoder(f)
	enc.SetIndent("", " ")
	if err := enc.Encode(r); err != nil {
		return err
	}

	fmt.Fprintf(os.Stderr, "Report saved to %s\n", path)
	return nil
}
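
To make the "site-specific report directories" part of the change concrete, the following standalone sketch (not part of the diff above) mirrors the path derivation used by saveReportToSiteDir; the target URL is an example value only.

package main

import (
	"fmt"
	"net/url"
	"path/filepath"
)

func main() {
	// The report path is reports/<hostname>/report.json, where the hostname
	// comes from parsing the crawl target.
	u, err := url.Parse("https://example.com/docs")
	if err != nil || u.Host == "" {
		panic("invalid target")
	}
	fmt.Println(filepath.Join("reports", u.Host, "report.json")) // reports/example.com/report.json
}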