gosint-sitecrawl/internal/crawlersafe/crawlersafe.go

241 lines
5.6 KiB
Go

package crawlersafe
import (
"context"
"io"
"net/http"
"sync"
"time"
"urlcrawler/internal/htmlx"
"urlcrawler/internal/urlutil"
)
type (
	// task is one unit of crawl work: a URL plus its link distance from the
	// start URL (the start URL itself has depth 0).
	task struct {
		url   string
		depth int
	}

	// PageInfo summarizes what the crawler recorded for a single URL.
	PageInfo struct {
		// Title is the extracted HTML <title>; empty for non-HTML, non-200 or
		// failed responses.
		Title string
		// ResponseTimeMs is the time in milliseconds from just before the
		// request was issued until the HTTP client returned a response or error.
		ResponseTimeMs int64
		// ContentLength is the number of body bytes read for HTML pages (the
		// crawler caps the read at 1MB); otherwise the server-declared
		// Content-Length when positive, else 0.
		ContentLength int
		// Depth is the link distance from the start URL.
		Depth int
	}
)
// CrawlWithSafety is a safer version of the crawler that adds a global timeout
// and limits the maximum number of URLs to process to prevent infinite loops.
//
// Starting from startURL, up to concurrency workers fetch pages with client,
// sending userAgent as the User-Agent header. Links are extracted only from
// responses with status 200 and a Content-Type starting with "text/html".
// A link is followed only while the linking page's depth is below maxDepth.
// When sameHostOnly is set, a link must also satisfy urlutil.SameHost against
// the origin of startURL.
//
// The crawl is bounded in several ways:
//   - at most 50 links are followed from any single page;
//   - at most maxURLs distinct URLs are ever scheduled (startURL included);
//   - the whole crawl is bounded by timeoutSeconds;
//   - each request is additionally bounded by 10 seconds.
//
// If visitedCallback is non-nil, it is called just before a URL is fetched.
// If errorCallback is non-nil, it is called when a request fails. Both receive
// the URL (plus the depth or error) and an approximate count of URLs still
// waiting to be dispatched. They run on worker goroutines and may be called
// concurrently.
//
// It returns the following maps:
//   - the set of URLs that were fetched;
//   - request errors keyed by URL;
//   - the links found on each HTML page, as returned by urlutil.Normalize;
//   - a PageInfo for every fetched URL.
//
// The function returns only after all workers have stopped, so the crawler no
// longer touches the maps and the caller may read them freely.
func CrawlWithSafety(ctx context.Context, startURL string, maxDepth int, concurrency int,
	sameHostOnly bool, client *http.Client, userAgent string,
	visitedCallback func(string, int, int), errorCallback func(string, error, int),
	maxURLs int, timeoutSeconds int) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) {
	const (
		requestTimeout  = 10 * time.Second // per-request bound, nested in the global timeout
		maxBodyBytes    = 1024 * 1024      // 1MB cap on HTML read per page
		maxLinksPerPage = 50               // links followed from any single page
	)
	// A negative size would make make(chan) panic, and zero workers could never
	// make progress.
	if concurrency < 1 {
		concurrency = 1
	}
	if client == nil {
		client = http.DefaultClient
	}

	crawlCtx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
	defer cancel()

	visited := make(map[string]struct{})
	errs := make(map[string]error)
	outlinks := make(map[string]map[string]struct{})
	pageInfos := make(map[string]PageInfo)
	var (
		mu       sync.Mutex // guards the four result maps and queueLen
		queueLen int        // URLs waiting for a worker; reported to callbacks
	)
	origin := urlutil.Origin(startURL)

	pending := func() int {
		mu.Lock()
		defer mu.Unlock()
		return queueLen
	}

	// fetchPage fetches one URL, records its results, and returns the child
	// tasks it wants crawled. The dispatcher deduplicates them and applies maxURLs.
	fetchPage := func(tk task) []task {
		mu.Lock()
		visited[tk.url] = struct{}{}
		mu.Unlock()
		if visitedCallback != nil {
			visitedCallback(tk.url, tk.depth, pending())
		}

		// The request context also governs reading resp.Body, so it must stay
		// alive until the body has been read; cancelling earlier truncates pages.
		reqCtx, reqCancel := context.WithTimeout(crawlCtx, requestTimeout)
		defer reqCancel()

		start := time.Now()
		var resp *http.Response
		req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, tk.url, nil)
		if err == nil {
			req.Header.Set("User-Agent", userAgent)
			resp, err = client.Do(req)
		}
		meta := PageInfo{ResponseTimeMs: time.Since(start).Milliseconds(), Depth: tk.depth}
		if err != nil {
			// Also covers unparsable URLs, where req would be nil.
			mu.Lock()
			errs[tk.url] = err
			pageInfos[tk.url] = meta
			mu.Unlock()
			if errorCallback != nil {
				errorCallback(tk.url, err, pending())
			}
			return nil
		}
		defer resp.Body.Close()

		if resp.ContentLength > 0 {
			meta.ContentLength = int(resp.ContentLength)
		}
		if resp.StatusCode != http.StatusOK || !hasPrefix(resp.Header.Get("Content-Type"), "text/html") {
			mu.Lock()
			pageInfos[tk.url] = meta
			mu.Unlock()
			return nil
		}

		// Best effort: on a read error, parse whatever bytes did arrive.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes))
		page := string(body)
		meta.ContentLength = len(body)
		meta.Title = htmlx.ExtractTitle(stringsReader(page))

		links := make(map[string]struct{})
		var next []task
		for _, href := range htmlx.ExtractAnchors(stringsReader(page)) {
			abs, ok := urlutil.Normalize(tk.url, href)
			if !ok {
				continue
			}
			links[abs] = struct{}{}
			if tk.depth < maxDepth && len(next) < maxLinksPerPage &&
				(!sameHostOnly || urlutil.SameHost(origin, abs)) {
				next = append(next, task{url: abs, depth: tk.depth + 1})
			}
		}

		mu.Lock()
		if len(links) > 0 {
			outlinks[tk.url] = links
		}
		pageInfos[tk.url] = meta
		mu.Unlock()
		return next
	}

	work := make(chan task)      // dispatcher -> workers
	results := make(chan []task) // workers -> dispatcher, one message per task
	var wg sync.WaitGroup
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for tk := range work {
				var next []task
				if crawlCtx.Err() == nil {
					next = fetchPage(tk)
				}
				// Report back so the dispatcher's in-flight count stays exact.
				// Once the crawl is over, nobody is listening, so give up.
				select {
				case results <- next:
				case <-crawlCtx.Done():
				}
			}
		}()
	}

	// The dispatcher runs on this goroutine. It alone owns the unbounded queue
	// and the dedup set, so workers never block on each other while handing
	// back links, and the work channel is closed only after every send is done.
	var queue []task
	scheduled := make(map[string]struct{})
	admit := func(tk task) {
		if len(scheduled) >= maxURLs {
			return
		}
		if _, dup := scheduled[tk.url]; dup {
			return
		}
		scheduled[tk.url] = struct{}{}
		queue = append(queue, tk)
	}
	publish := func() {
		mu.Lock()
		queueLen = len(queue)
		mu.Unlock()
	}

	admit(task{url: startURL, depth: 0})
	publish()
	inFlight := 0
	for (len(queue) > 0 || inFlight > 0) && crawlCtx.Err() == nil {
		var out chan<- task // stays nil (never ready) while the queue is empty
		var head task
		if len(queue) > 0 {
			out, head = work, queue[0]
		}
		select {
		case out <- head:
			queue[0] = task{} // drop the reference so the string can be collected
			queue = queue[1:]
			inFlight++
		case next := <-results:
			inFlight--
			for _, child := range next {
				admit(child)
			}
		case <-crawlCtx.Done():
		}
		publish()
	}
	close(work)
	// In-flight requests share crawlCtx, so after a timeout they return promptly.
	wg.Wait()
	return visited, errs, outlinks, pageInfos
}
// hasPrefix reports whether s begins with prefix. It is a dependency-free
// equivalent of strings.HasPrefix; an empty prefix matches every string.
func hasPrefix(s string, prefix string) bool {
	if len(prefix) > len(s) {
		return false
	}
	head := s[:len(prefix)]
	return head == prefix
}
// stringsReader returns an io.Reader yielding the bytes of s. It stands in for
// strings.NewReader so the package does not import strings for a single use.
func stringsReader(s string) io.Reader {
	sr := new(stringReader)
	sr.str = s
	return sr
}

// stringReader is a minimal single-pass io.Reader over a string. Each Read
// consumes bytes from the front of str. Once str is empty, Read reports io.EOF.
type stringReader struct {
	str string // bytes not yet returned by Read
}

// Read copies up to len(buf) unread bytes into buf and advances past them.
// It returns io.EOF only when no bytes remain.
func (sr *stringReader) Read(buf []byte) (int, error) {
	rest := sr.str
	if len(rest) == 0 {
		return 0, io.EOF
	}
	copied := copy(buf, rest)
	sr.str = rest[copied:]
	return copied, nil
}