gosint-sitecrawl/safe_crawler.go

package main

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync/atomic"
	"time"

	"urlcrawler/internal/crawlersafe"
	"urlcrawler/internal/linkcheck"
	"urlcrawler/internal/report"
	"urlcrawler/internal/sitemap"
	"urlcrawler/internal/urlutil"
)
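
// main crawls the target site with the bounded crawler, collects sitemap and
// robots.txt metadata, status-checks every link it discovered, and emits the
// report as text or JSON (a copy is also written to reports/<hostname>/report.json).
//
// Example invocation (illustrative values; run from the module directory):
//
//	go run . -target https://example.com -max-depth 2 -concurrency 8 -output json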
func main() {
	var target string
	var concurrency int
	var timeout time.Duration
	var maxDepth int
	var userAgent string
	var sameHostOnly bool
	var output string
	var quiet bool
	var maxURLs int
	var globalTimeout int
	flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)")
	flag.IntVar(&concurrency, "concurrency", 5, "Number of concurrent workers")
	flag.DurationVar(&timeout, "timeout", 10*time.Second, "HTTP timeout per request")
	flag.IntVar(&maxDepth, "max-depth", 2, "Maximum crawl depth (0=crawl only the start page)")
	flag.StringVar(&userAgent, "user-agent", "urlcrawler/1.0", "User-Agent header value")
	flag.BoolVar(&sameHostOnly, "same-host-only", true, "Limit crawl to the same host as target")
	flag.StringVar(&output, "output", "text", "Output format: text|json")
	flag.BoolVar(&quiet, "quiet", false, "Suppress progress output")
	flag.IntVar(&maxURLs, "max-urls", 100, "Maximum number of URLs to crawl")
	flag.IntVar(&globalTimeout, "global-timeout", 60, "Global timeout in seconds for the entire crawl")
	flag.Parse()
	if strings.TrimSpace(target) == "" {
		fmt.Fprintln(os.Stderr, "-target is required")
		flag.Usage()
		os.Exit(2)
	}

	client := &http.Client{Timeout: timeout}
	ctx := context.Background()
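
	// Note: -timeout bounds each individual HTTP request via the client; the
	// whole-crawl budget is the separate -global-timeout value handed to
	// crawlersafe.CrawlWithSafety below, so the base context carries no deadline.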
	// Report metadata
	started := time.Now()
	meta := report.Metadata{StartedAt: started.UTC().Format(time.RFC3339)}
	params := report.Params{
		MaxDepth:     maxDepth,
		Concurrency:  concurrency,
		TimeoutMs:    timeout.Milliseconds(),
		UserAgent:    userAgent,
		SameHostOnly: sameHostOnly,
	}
	if !quiet {
		fmt.Fprintf(os.Stderr, "Starting crawl of %s (depth: %d, max URLs: %d, timeout: %ds)...\n",
			target, maxDepth, maxURLs, globalTimeout)
	}
	// Setup progress counters
	var urlsVisited, urlsErrored atomic.Int64
	var currentURL atomic.Value // string
	var pendingTasks atomic.Int64
	// Start progress reporter if not in quiet mode
	ctxWithCancel, cancel := context.WithCancel(ctx)
	defer cancel()
	if !quiet {
		go func() {
			ticker := time.NewTicker(500 * time.Millisecond)
			defer ticker.Stop()
			for {
				select {
				case <-ticker.C:
					cu, _ := currentURL.Load().(string)
					fmt.Fprintf(os.Stderr, "\rURLs visited: %d | Errors: %d | Pending: %d | Current: %s",
						urlsVisited.Load(), urlsErrored.Load(), pendingTasks.Load(), truncateForTTY(cu, 90))
				case <-ctxWithCancel.Done():
					return
				}
			}
		}()
	}
	// Progress callback functions
	visitedCallback := func(u string, depth int, pending int) {
		urlsVisited.Add(1)
		pendingTasks.Store(int64(pending))
		currentURL.Store(u)
	}
	errorCallback := func(u string, err error, pending int) {
		urlsErrored.Add(1)
		pendingTasks.Store(int64(pending))
		currentURL.Store(u)
	}
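	// crawlersafe.CrawlWithSafety drives the whole fetch phase: it calls the
	// two callbacks above as pages complete and is expected to stop early once
	// max-urls pages have been visited or global-timeout seconds have elapsed.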
	// Use the safer crawler with timeout and URL limits
	visited, crawlErrs, outlinks, pageInfo := crawlersafe.CrawlWithSafety(
		ctx, target, maxDepth, concurrency, sameHostOnly,
		client, userAgent, visitedCallback, errorCallback,
		maxURLs, globalTimeout)
	// Stop the progress reporter so its ticker cannot overwrite the summary
	// line or the status messages that follow (the deferred cancel later is a
	// harmless no-op).
	cancel()
	if !quiet {
		fmt.Fprintf(os.Stderr, "\rCrawl complete! URLs visited: %d | Errors: %d\n",
			urlsVisited.Load(), urlsErrored.Load())
	}
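	// Sitemap discovery: a missing sitemap (sitemap.ErrNotFound) is common
	// and deliberately not surfaced as an error below.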
fmt.Fprintf(os.Stderr, "Fetching sitemap...\n")
smURLs, err := sitemap.FetchAll(ctx, target, client, userAgent)
if err != nil && !errors.Is(err, sitemap.ErrNotFound) {
fmt.Fprintf(os.Stderr, "sitemap error: %v\n", err)
}
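	// The robots.txt probe below only records presence and fetch time for the
	// report; no robots rules are parsed or enforced in this file.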
	// Robots.txt summary (simple)
	robots := report.RobotsSummary{}
	robotsURL := urlutil.Origin(target) + "/robots.txt"
	if req, err := http.NewRequestWithContext(ctx, http.MethodGet, robotsURL, nil); err == nil {
		req.Header.Set("User-Agent", userAgent)
		if resp, err := client.Do(req); err == nil {
			if resp.StatusCode == http.StatusOK {
				robots.Present = true
				robots.FetchedAt = time.Now().UTC().Format(time.RFC3339)
			}
			// Close the body now; a defer here would hold the connection
			// open until main returns.
			resp.Body.Close()
		}
	}
	// Build set of all unique links discovered across pages for status checks
	allLinks := make(map[string]struct{})
	for _, m := range outlinks {
		for u := range m {
			allLinks[u] = struct{}{}
		}
	}
	// Also include the visited pages themselves
	for u := range visited {
		allLinks[u] = struct{}{}
	}
	if !quiet {
		fmt.Fprintf(os.Stderr, "Checking %d links...\n", len(allLinks))
	}
	// Reset counters for link checking
	urlsVisited.Store(0)
	urlsErrored.Store(0)
	// Progress callback function for link checking
	linkCheckCallback := func(ok bool) {
		if ok {
			urlsVisited.Add(1)
		} else {
			urlsErrored.Add(1)
		}
	}
	checkResults := linkcheck.Check(ctx, allLinks, concurrency, client, userAgent, !quiet, linkCheckCallback)
	// Clear progress line before finishing
	if !quiet {
		fmt.Fprintf(os.Stderr, "\rLink checking complete! OK: %d | Errors: %d\n",
			urlsVisited.Load(), urlsErrored.Load())
	}
	finished := time.Now()
	meta.FinishedAt = finished.UTC().Format(time.RFC3339)
	meta.DurationMs = finished.Sub(started).Milliseconds()
	if !quiet {
		fmt.Fprintf(os.Stderr, "Building report...\n")
	}
	// Convert pageInfo to report.PageMeta
	pages := make(map[string]report.PageMeta, len(pageInfo))
	for u, pi := range pageInfo {
		pages[u] = report.PageMeta{
			Title:          pi.Title,
			ResponseTimeMs: pi.ResponseTimeMs,
			ContentLength:  pi.ContentLength,
			Depth:          pi.Depth,
		}
	}
	reports := report.Build(target, visited, smURLs, crawlErrs, checkResults, outlinks, meta, params, pages, robots)
	// Save report to a subdirectory named after the site
	if err := saveReportToSiteDir(reports); err != nil {
		fmt.Fprintf(os.Stderr, "save report error: %v\n", err)
	}
	// Output the report
	switch output {
	case "json":
		enc := json.NewEncoder(os.Stdout)
		enc.SetIndent("", " ")
		if err := enc.Encode(reports); err != nil {
			fmt.Fprintf(os.Stderr, "encode error: %v\n", err)
		}
	default:
		report.PrintText(os.Stdout, reports)
	}
}

// truncateForTTY truncates s to at most max characters, replacing the tail
// with … if needed. It counts runes rather than bytes so multi-byte UTF-8
// characters are never split mid-sequence.
func truncateForTTY(s string, max int) string {
	r := []rune(s)
	if max <= 0 || len(r) <= max {
		return s
	}
	if max <= 1 {
		return "…"
	}
	return string(r[:max-1]) + "…"
}

// saveReportToSiteDir saves the report to a subdirectory named after the site's
// hostname under the "reports" directory.
func saveReportToSiteDir(r report.Report) error {
	u, err := url.Parse(r.Target)
	if err != nil || u.Hostname() == "" {
		return fmt.Errorf("invalid target for save: %s", r.Target)
	}
	// Hostname() strips any :port, keeping the directory name valid on all
	// platforms; MkdirAll creates the parent "reports" directory as well.
	siteDir := filepath.Join("reports", u.Hostname())
	if err := os.MkdirAll(siteDir, 0o755); err != nil {
		return err
	}
	// Save report to site subdirectory
	path := filepath.Join(siteDir, "report.json")
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	enc := json.NewEncoder(f)
	enc.SetIndent("", " ")
	if err := enc.Encode(r); err != nil {
		f.Close()
		return err
	}
	// Close explicitly so a deferred write error is not silently dropped.
	if err := f.Close(); err != nil {
		return err
	}
	fmt.Fprintf(os.Stderr, "Report saved to %s\n", path)
	return nil
}