gosint-sitecrawl/main.go

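// Command gosint-sitecrawl crawls a target site, fetches its sitemap,
// status-checks every link it discovers, and prints a text or JSON report.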
package main

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"net/http"
	"os"
	"strings"
	"sync/atomic"
	"time"

	"urlcrawler/internal/crawler"
	"urlcrawler/internal/linkcheck"
	"urlcrawler/internal/report"
	"urlcrawler/internal/sitemap"
)
func main() {
	var (
		target       string
		concurrency  int
		timeout      time.Duration
		maxDepth     int
		userAgent    string
		sameHostOnly bool
		output       string
		quiet        bool
	)
	flag.StringVar(&target, "target", "", "Target site URL (e.g., https://example.com)")
	flag.IntVar(&concurrency, "concurrency", 10, "Number of concurrent workers")
	flag.DurationVar(&timeout, "timeout", 10*time.Second, "HTTP timeout per request")
	flag.IntVar(&maxDepth, "max-depth", 2, "Maximum crawl depth (0=crawl only the start page)")
	flag.StringVar(&userAgent, "user-agent", "urlcrawler/1.0", "User-Agent header value")
	flag.BoolVar(&sameHostOnly, "same-host-only", true, "Limit crawl to the same host as target")
	flag.StringVar(&output, "output", "text", "Output format: text|json")
	flag.BoolVar(&quiet, "quiet", false, "Suppress progress output")
	flag.Parse()

	if strings.TrimSpace(target) == "" {
		fmt.Fprintln(os.Stderr, "-target is required")
		flag.Usage()
		os.Exit(2)
	}
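
	// Example invocation, run from the package directory (flags as defined above):
	//
	//	go run . -target https://example.com -max-depth 3 -concurrency 20 -output json
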
	client := &http.Client{Timeout: timeout}
	ctx := context.Background()

	// Report metadata
	started := time.Now()
	meta := report.Metadata{StartedAt: started.UTC().Format(time.RFC3339)}
	params := report.Params{
		MaxDepth:     maxDepth,
		Concurrency:  concurrency,
		TimeoutMs:    timeout.Milliseconds(),
		UserAgent:    userAgent,
		SameHostOnly: sameHostOnly,
	}

	fmt.Fprintf(os.Stderr, "Starting crawl of %s (depth: %d)...\n", target, maxDepth)

	// Set up progress counters
	var urlsVisited, urlsErrored atomic.Int64
	var currentURL atomic.Value // string
	var pendingTasks atomic.Int64
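
	// The reporter below redraws a single stderr status line in place via a
	// carriage return; with illustrative values it renders as:
	//
	//	URLs visited: 42 | Errors: 3 | Pending: 17 | Current: https://example.com/docs
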
	// Start progress reporter if not in quiet mode. It runs until main
	// returns, so it also tracks the link-checking phase further down.
	ctxWithCancel, cancel := context.WithCancel(ctx)
	defer cancel()
	if !quiet {
		go func() {
			ticker := time.NewTicker(500 * time.Millisecond)
			defer ticker.Stop()
			for {
				select {
				case <-ticker.C:
					cu, _ := currentURL.Load().(string)
					fmt.Fprintf(os.Stderr, "\rURLs visited: %d | Errors: %d | Pending: %d | Current: %s",
						urlsVisited.Load(), urlsErrored.Load(), pendingTasks.Load(), truncateForTTY(cu, 90))
				case <-ctxWithCancel.Done():
					return
				}
			}
		}()
	}

	// Progress callbacks invoked by the crawler for visits and errors
	visitedCallback := func(u string, depth int, pending int) {
		urlsVisited.Add(1)
		pendingTasks.Store(int64(pending))
		currentURL.Store(u)
	}
	errorCallback := func(u string, err error, pending int) {
		urlsErrored.Add(1)
		pendingTasks.Store(int64(pending))
		currentURL.Store(u)
	}

	visited, crawlErrs, outlinks := crawler.Crawl(ctx, target, maxDepth, concurrency, sameHostOnly, client, userAgent, visitedCallback, errorCallback)

	// Overwrite the progress line with a crawl summary before the next phase
	if !quiet {
		fmt.Fprintf(os.Stderr, "\rCrawl complete! URLs visited: %d | Errors: %d\n",
			urlsVisited.Load(), urlsErrored.Load())
	}
fmt.Fprintf(os.Stderr, "Fetching sitemap...\n")
smURLs, err := sitemap.FetchAll(ctx, target, client, userAgent)
if err != nil && !errors.Is(err, sitemap.ErrNotFound) {
fmt.Fprintf(os.Stderr, "sitemap error: %v\n", err)
}

	// Build set of all unique links discovered across pages for status checks
	allLinks := make(map[string]struct{})
	for _, m := range outlinks {
		for u := range m {
			allLinks[u] = struct{}{}
		}
	}
	// Also include the visited pages themselves
	for u := range visited {
		allLinks[u] = struct{}{}
	}
fmt.Fprintf(os.Stderr, "Checking %d links...\n", len(allLinks))
// Reset counters for link checking
urlsVisited.Store(0)
urlsErrored.Store(0)
// Progress callback functions for link checking
linkCheckCallback := func(ok bool) {
if ok {
urlsVisited.Add(1)
} else {
urlsErrored.Add(1)
}
}
checkResults := linkcheck.Check(ctx, allLinks, concurrency, client, userAgent, !quiet, linkCheckCallback)

	// Overwrite the progress line with a link-check summary before finishing
	if !quiet {
		fmt.Fprintf(os.Stderr, "\rLink checking complete! OK: %d | Errors: %d\n",
			urlsVisited.Load(), urlsErrored.Load())
	}

	finished := time.Now()
	meta.FinishedAt = finished.UTC().Format(time.RFC3339)
	meta.DurationMs = finished.Sub(started).Milliseconds()

	fmt.Fprintf(os.Stderr, "Building report...\n")
	reports := report.Build(target, visited, smURLs, crawlErrs, checkResults, outlinks, meta, params)
	switch output {
	case "json":
		enc := json.NewEncoder(os.Stdout)
		enc.SetIndent("", " ")
		_ = enc.Encode(reports)
	default:
		report.PrintText(os.Stdout, reports)
	}
}

// truncateForTTY truncates s to max characters, replacing the tail with … if needed.
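// For example, truncateForTTY("https://example.com/very/long/path", 20) returns
// "https://example.com…".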
func truncateForTTY(s string, max int) string {
	// Operate on runes, not bytes, so a multi-byte UTF-8 character is never
	// split mid-sequence.
	r := []rune(s)
	if max <= 0 || len(r) <= max {
		return s
	}
	if max <= 1 {
		return "…"
	}
	return string(r[:max-1]) + "…"
}