package crawlersafe

import (
	"context"
	"io"
	"net/http"
	"sync"
	"time"

	"urlcrawler/internal/htmlx"
	"urlcrawler/internal/urlutil"
)

type task struct {
	url   string
	depth int
}

type PageInfo struct {
	Title          string
	ResponseTimeMs int64
	ContentLength  int
	Depth          int
}

// CrawlWithSafety is a safer version of the crawler that adds a global timeout
// and limits the maximum number of URLs to process to prevent infinite loops.
// It returns the set of visited URLs, per-URL errors, the outlink graph, and
// per-page metadata.
func CrawlWithSafety(ctx context.Context, startURL string, maxDepth int, concurrency int, sameHostOnly bool, client *http.Client, userAgent string, visitedCallback func(string, int, int), errorCallback func(string, error, int), maxURLs int, timeoutSeconds int) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) {
	// Create a context with the global crawl timeout.
	ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
	defer cancel()

	visited := make(map[string]struct{})
	errs := make(map[string]error)
	outlinks := make(map[string]map[string]struct{})
	pageInfos := make(map[string]PageInfo)

	var mu sync.Mutex
	var urlCounter int

	origin := urlutil.Origin(startURL)
	tasks := make(chan task, concurrency*2)

	var wgWorkers sync.WaitGroup
	var wgTasks sync.WaitGroup

	enqueue := func(t task) {
		mu.Lock()
		// Check if we've reached the URL limit.
		if urlCounter >= maxURLs {
			mu.Unlock()
			return
		}
		urlCounter++
		// Release the lock before the (potentially blocking) channel send so
		// workers waiting on mu can keep draining the tasks channel.
		mu.Unlock()

		wgTasks.Add(1)
		select {
		case tasks <- t:
			// Successfully enqueued.
		case <-ctxWithTimeout.Done():
			// Context canceled or timed out.
			wgTasks.Done()
		}
	}

	worker := func() {
		defer wgWorkers.Done()
		for {
			select {
			case <-ctxWithTimeout.Done():
				return
			case tk, ok := <-tasks:
				if !ok {
					return
				}
				if ctxWithTimeout.Err() != nil {
					wgTasks.Done()
					return
				}

				mu.Lock()
				if _, seen := visited[tk.url]; seen {
					mu.Unlock()
					wgTasks.Done()
					continue
				}
				visited[tk.url] = struct{}{}
				mu.Unlock()

				if visitedCallback != nil {
					visitedCallback(tk.url, tk.depth, len(tasks))
				}

				// Create a per-request timeout derived from the global context.
				reqCtx, reqCancel := context.WithTimeout(ctxWithTimeout, 10*time.Second)
				start := time.Now()
				req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, tk.url, nil)
				if err != nil {
					reqCancel()
					mu.Lock()
					errs[tk.url] = err
					pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
					mu.Unlock()
					if errorCallback != nil {
						errorCallback(tk.url, err, len(tasks))
					}
					wgTasks.Done()
					continue
				}
				req.Header.Set("User-Agent", userAgent)

				resp, err := client.Do(req)
				if err != nil {
					reqCancel()
					mu.Lock()
					errs[tk.url] = err
					pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
					mu.Unlock()
					if errorCallback != nil {
						errorCallback(tk.url, err, len(tasks))
					}
					wgTasks.Done()
					continue
				}

				func() {
					// The request context must stay alive until the body has
					// been read; cancel it only when this closure returns.
					defer reqCancel()
					defer resp.Body.Close()

					ct := resp.Header.Get("Content-Type")
					// Default meta values.
					meta := PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
					if resp.ContentLength > 0 {
						meta.ContentLength = int(resp.ContentLength)
					}
					// Only parse successful HTML responses.
					if resp.StatusCode != http.StatusOK || !hasPrefix(ct, "text/html") {
						mu.Lock()
						pageInfos[tk.url] = meta
						mu.Unlock()
						return
					}

					// Limit body read to prevent memory issues (1MB limit).
					body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
					page := string(body)
					meta.ContentLength = len(body)
					meta.Title = htmlx.ExtractTitle(stringsReader(page))

					hrefs := htmlx.ExtractAnchors(stringsReader(page))
					var toEnqueue []string
					for _, href := range hrefs {
						abs, ok := urlutil.Normalize(tk.url, href)
						if !ok {
							continue
						}
						mu.Lock()
						m, ok2 := outlinks[tk.url]
						if !ok2 {
							m = make(map[string]struct{})
							outlinks[tk.url] = m
						}
						m[abs] = struct{}{}
						mu.Unlock()

						if tk.depth < maxDepth {
							if !sameHostOnly || urlutil.SameHost(origin, abs) {
								toEnqueue = append(toEnqueue, abs)
							}
						}
					}

					// Limit the number of links to follow from a single page.
					if len(toEnqueue) > 50 {
						toEnqueue = toEnqueue[:50]
					}
					for _, u := range toEnqueue {
						enqueue(task{url: u, depth: tk.depth + 1})
					}

					mu.Lock()
					pageInfos[tk.url] = meta
					mu.Unlock()
				}()
				wgTasks.Done()
			}
		}
	}

	for i := 0; i < concurrency; i++ {
		wgWorkers.Add(1)
		go worker()
	}

	// Seed the crawl before starting the closer goroutine below, so wgTasks is
	// already non-zero when that goroutine begins waiting on it.
	enqueue(task{url: startURL, depth: 0})

	// Close the tasks channel when all enqueued tasks are processed or the
	// global timeout fires, so idle workers blocked on the channel can exit.
	go func() {
		done := make(chan struct{})
		go func() {
			wgTasks.Wait()
			close(done)
		}()
		select {
		case <-done:
			// All tasks completed.
		case <-ctxWithTimeout.Done():
			// Timeout occurred.
		}
		close(tasks)
	}()

	// Wait for every worker to exit. Once the timeout fires, context
	// cancellation makes in-flight requests fail quickly, so this does not
	// block for long, and it guarantees no worker is still writing to the
	// maps returned below.
	wgWorkers.Wait()

	return visited, errs, outlinks, pageInfos
}

func hasPrefix(s string, prefix string) bool {
	return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}

// stringsReader avoids importing the strings package at the top of the file
// just for strings.NewReader.
func stringsReader(s string) io.Reader { return &stringReader{str: s} }

type stringReader struct{ str string }

func (r *stringReader) Read(p []byte) (int, error) {
	if len(r.str) == 0 {
		return 0, io.EOF
	}
	n := copy(p, r.str)
	r.str = r.str[n:]
	return n, nil
}
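
// crawlExample is a minimal usage sketch for CrawlWithSafety. It is
// illustrative only: the start URL, client timeout, user agent, and limits
// below are assumptions chosen for the example, not values required by this
// package, and the function is not referenced anywhere else in the crawler.
func crawlExample(ctx context.Context) (map[string]struct{}, map[string]error) {
	client := &http.Client{Timeout: 15 * time.Second} // per-request client timeout (assumed)
	visited, errs, _, _ := CrawlWithSafety(
		ctx,
		"https://example.com", // startURL (assumed)
		2,                     // maxDepth: follow links up to two levels deep
		8,                     // concurrency: number of worker goroutines
		true,                  // sameHostOnly: stay on the start URL's host
		client,
		"urlcrawler-bot/1.0", // userAgent (assumed)
		nil,                  // visitedCallback: omitted in this sketch
		nil,                  // errorCallback: omitted in this sketch
		500,                  // maxURLs: hard cap on pages fetched
		60,                   // timeoutSeconds: global crawl deadline
	)
	return visited, errs
}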