package crawler

import (
	"context"
	"io"
	"net/http"
	"strings"
	"sync"

	"urlcrawler/internal/htmlx"
	"urlcrawler/internal/urlutil"
)

// task is one unit of crawl work: a URL and the depth at which it was discovered.
type task struct {
	url   string
	depth int
}

// Crawl visits pages starting at startURL, following links up to maxDepth, and returns
// the set of visited URLs, per-URL errors, and per-page outgoing links.
// visitedCallback is invoked after a page is claimed for visiting with the URL, its depth,
// and the current number of buffered tasks; errorCallback is invoked on a fetch error with
// the URL, the error, and the current number of buffered tasks. Either callback may be nil.
func Crawl(ctx context.Context, startURL string, maxDepth int, concurrency int, sameHostOnly bool, client *http.Client, userAgent string, visitedCallback func(string, int, int), errorCallback func(string, error, int)) (map[string]struct{}, map[string]error, map[string]map[string]struct{}) {
	visited := make(map[string]struct{})
	errs := make(map[string]error)
	outlinks := make(map[string]map[string]struct{})
	var mu sync.Mutex // guards visited, errs, and outlinks

	origin := urlutil.Origin(startURL)

	// Guard against non-positive concurrency so at least one worker drains the queue.
	if concurrency < 1 {
		concurrency = 1
	}

	tasks := make(chan task, concurrency*2)
	var wgWorkers sync.WaitGroup // tracks running workers
	var wgTasks sync.WaitGroup   // tracks tasks enqueued but not yet finished

	// recordErr stores a per-URL error and notifies the error callback, if any.
	recordErr := func(u string, err error) {
		mu.Lock()
		errs[u] = err
		mu.Unlock()
		if errorCallback != nil {
			errorCallback(u, err, len(tasks))
		}
	}

	// enqueue registers the task before sending it. The send runs in its own goroutine
	// so a worker that discovers many links never blocks on a full channel while every
	// other worker is also blocked sending, which would deadlock the crawl.
	enqueue := func(t task) {
		wgTasks.Add(1)
		go func() { tasks <- t }()
	}

	worker := func() {
		defer wgWorkers.Done()
		for tk := range tasks {
			if ctx.Err() != nil {
				// Context cancelled: keep draining so wgTasks reaches zero and the
				// channel is eventually closed, but do no further work.
				wgTasks.Done()
				continue
			}

			// Skip URLs another worker has already claimed.
			mu.Lock()
			if _, seen := visited[tk.url]; seen {
				mu.Unlock()
				wgTasks.Done()
				continue
			}
			visited[tk.url] = struct{}{}
			mu.Unlock()

			if visitedCallback != nil {
				visitedCallback(tk.url, tk.depth, len(tasks))
			}

			req, err := http.NewRequestWithContext(ctx, http.MethodGet, tk.url, nil)
			if err != nil {
				recordErr(tk.url, err)
				wgTasks.Done()
				continue
			}
			req.Header.Set("User-Agent", userAgent)

			resp, err := client.Do(req)
			if err != nil {
				recordErr(tk.url, err)
				wgTasks.Done()
				continue
			}

			func() {
				defer resp.Body.Close()

				// Only parse successful HTML responses.
				ct := resp.Header.Get("Content-Type")
				if resp.StatusCode != http.StatusOK || !strings.HasPrefix(ct, "text/html") {
					return
				}

				body, err := io.ReadAll(resp.Body)
				if err != nil {
					recordErr(tk.url, err)
					return
				}

				hrefs := htmlx.ExtractAnchors(strings.NewReader(string(body)))
				var toEnqueue []string
				for _, href := range hrefs {
					abs, ok := urlutil.Normalize(tk.url, href)
					if !ok {
						continue
					}

					// Record the outgoing link regardless of whether it will be followed.
					mu.Lock()
					m, ok2 := outlinks[tk.url]
					if !ok2 {
						m = make(map[string]struct{})
						outlinks[tk.url] = m
					}
					m[abs] = struct{}{}
					mu.Unlock()

					if tk.depth < maxDepth && (!sameHostOnly || urlutil.SameHost(origin, abs)) {
						toEnqueue = append(toEnqueue, abs)
					}
				}
				for _, u := range toEnqueue {
					enqueue(task{url: u, depth: tk.depth + 1})
				}
			}()

			wgTasks.Done()
		}
	}

	for i := 0; i < concurrency; i++ {
		wgWorkers.Add(1)
		go worker()
	}

	// Seed the crawl before starting the closer goroutine, so wgTasks cannot hit zero
	// (and close the channel) before the first task is registered.
	enqueue(task{url: startURL, depth: 0})

	// Close the tasks channel once every enqueued task has been processed, which lets
	// the workers' range loops terminate.
	go func() {
		wgTasks.Wait()
		close(tasks)
	}()

	wgWorkers.Wait()
	return visited, errs, outlinks
}
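
// Example usage (a minimal sketch, not part of the package API): the caller shown here is
// an assumption, but code inside the urlcrawler module could drive Crawl roughly like this:
//
//	client := &http.Client{Timeout: 10 * time.Second}
//	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
//	defer cancel()
//	onVisit := func(u string, depth, pending int) { log.Printf("visit %s depth=%d pending=%d", u, depth, pending) }
//	onErr := func(u string, err error, pending int) { log.Printf("error %s: %v", u, err) }
//	visited, errs, outlinks := crawler.Crawl(ctx, "https://example.com", 2, 8, true, client, "urlcrawler/0.1", onVisit, onErr)
//	log.Printf("visited=%d errors=%d pages-with-links=%d", len(visited), len(errs), len(outlinks))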