package crawler

import (
	"context"
	"io"
	"net/http"
	"sync"
	"time"

	"urlcrawler/internal/htmlx"
	"urlcrawler/internal/urlutil"
)

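// task is a single unit of crawl work: a URL to fetch and the depth at which
// it was discovered relative to the start URL.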
type task struct {
	url   string
	depth int
}

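// PageInfo holds per-page metadata collected during the crawl: the document
// title, the response time in milliseconds, the content length in bytes, and
// the depth at which the page was reached.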
type PageInfo struct {
	Title          string
	ResponseTimeMs int64
	ContentLength  int
	Depth          int
}

// Crawl visits pages up to maxDepth and returns the visited set, per-URL errors, and per-page outgoing links.
// The visitedCallback and errorCallback functions are called when a page is successfully visited or encounters an error.
// visitedCallback receives the URL, its depth, and the current number of pending tasks in the queue.
// errorCallback receives the URL, the error, and the current number of pending tasks in the queue.
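//
// A minimal usage sketch (illustrative only; the timeout, depth, concurrency,
// and User-Agent values are arbitrary examples, not defaults of Crawl, and
// both callbacks may be nil):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
//	defer cancel()
//	visited, errs, outlinks, pages := Crawl(
//		ctx, "https://example.com",
//		2,    // maxDepth
//		8,    // concurrency
//		true, // sameHostOnly
//		&http.Client{Timeout: 10 * time.Second},
//		"urlcrawler/0.1",
//		nil, nil, // no progress callbacks
//	)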
func Crawl(ctx context.Context, startURL string, maxDepth int, concurrency int, sameHostOnly bool, client *http.Client, userAgent string, visitedCallback func(string, int, int), errorCallback func(string, error, int)) (map[string]struct{}, map[string]error, map[string]map[string]struct{}, map[string]PageInfo) {
	visited := make(map[string]struct{})
	errs := make(map[string]error)
	outlinks := make(map[string]map[string]struct{})
	pageInfos := make(map[string]PageInfo)
	var mu sync.Mutex

	origin := urlutil.Origin(startURL)

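	// tasks carries pending URLs to the workers. wgWorkers tracks the worker
	// goroutines themselves; wgTasks counts outstanding tasks so the channel
	// can be closed once every enqueued URL has been handled.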
	tasks := make(chan task, concurrency*2)
	wgWorkers := sync.WaitGroup{}
	wgTasks := sync.WaitGroup{}

	// enqueue registers a task with wgTasks and hands it to the workers. The
	// send runs in its own goroutine so a worker that is enqueuing newly
	// discovered links cannot block on a full tasks channel while every other
	// worker does the same, which would deadlock the pool.
	enqueue := func(t task) {
		wgTasks.Add(1)
		go func() { tasks <- t }()
	}

	worker := func() {
		defer wgWorkers.Done()
		for tk := range tasks {
			if ctx.Err() != nil {
				wgTasks.Done()
				return
			}
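			// Claim the URL under the lock before fetching so that no two
			// workers ever process the same page.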
			mu.Lock()
			if _, seen := visited[tk.url]; seen {
				mu.Unlock()
				wgTasks.Done()
				continue
			}
			visited[tk.url] = struct{}{}
			mu.Unlock()

			if visitedCallback != nil {
				visitedCallback(tk.url, tk.depth, len(tasks))
			}

			start := time.Now()
			// Building the request can fail for a malformed URL; fold that
			// error into the same handling as a failed fetch rather than
			// dereferencing a nil request.
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, tk.url, nil)
			var resp *http.Response
			if err == nil {
				req.Header.Set("User-Agent", userAgent)
				resp, err = client.Do(req)
			}
			if err != nil {
				mu.Lock()
				errs[tk.url] = err
				pageInfos[tk.url] = PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
				mu.Unlock()

				if errorCallback != nil {
					errorCallback(tk.url, err, len(tasks))
				}
				wgTasks.Done()
				continue
			}
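			// Handle the response inside a closure so the deferred Body.Close
			// runs at the end of each iteration rather than at worker exit.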
			func() {
				defer resp.Body.Close()
				ct := resp.Header.Get("Content-Type")
				// Default meta values
				meta := PageInfo{Title: "", ResponseTimeMs: time.Since(start).Milliseconds(), ContentLength: 0, Depth: tk.depth}
				if resp.ContentLength > 0 {
					meta.ContentLength = int(resp.ContentLength)
				}
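				// Only 200 OK responses with a text/html content type are
				// parsed for a title and links; everything else is recorded
				// with its default metadata only.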
				if resp.StatusCode != http.StatusOK || ct == "" || (ct != "text/html" && !hasPrefix(ct, "text/html")) {
					mu.Lock()
					pageInfos[tk.url] = meta
					mu.Unlock()
					return
				}
				body, _ := io.ReadAll(resp.Body)
				meta.ContentLength = len(body)
				meta.Title = htmlx.ExtractTitle(stringsReader(string(body)))
				hrefs := htmlx.ExtractAnchors(stringsReader(string(body)))
				var toEnqueue []string
				for _, href := range hrefs {
					abs, ok := urlutil.Normalize(tk.url, href)
					if !ok {
						continue
					}
					mu.Lock()
					m, ok2 := outlinks[tk.url]
					if !ok2 {
						m = make(map[string]struct{})
						outlinks[tk.url] = m
					}
					m[abs] = struct{}{}
					mu.Unlock()

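					// Every normalized link is recorded as an outlink, but only
					// links within maxDepth (and, when sameHostOnly is set, on
					// the same origin) are queued for crawling.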
					if tk.depth < maxDepth {
						if !sameHostOnly || urlutil.SameHost(origin, abs) {
							toEnqueue = append(toEnqueue, abs)
						}
					}
				}
				for _, u := range toEnqueue {
					enqueue(task{url: u, depth: tk.depth + 1})
				}
				mu.Lock()
				pageInfos[tk.url] = meta
				mu.Unlock()
			}()
			wgTasks.Done()
		}
	}

	for i := 0; i < concurrency; i++ {
		wgWorkers.Add(1)
		go worker()
	}

	// Seed the queue before starting the closer goroutine, so the task counter
	// cannot be observed at zero (and the channel closed) before the first URL
	// has been registered.
	enqueue(task{url: startURL, depth: 0})

	// Close the tasks channel when all enqueued tasks are processed.
	go func() {
		wgTasks.Wait()
		close(tasks)
	}()

	wgWorkers.Wait()

	return visited, errs, outlinks, pageInfos
}

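// hasPrefix reports whether s begins with prefix; it mirrors strings.HasPrefix,
// kept local for the same reason as stringsReader below.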
func hasPrefix(s string, prefix string) bool {
	return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}

// stringsReader avoids importing strings at package top for a single use.
func stringsReader(s string) io.Reader {
	return &stringReader{str: s}
}

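// stringReader is a minimal io.Reader over a string, a small local stand-in
// for strings.Reader.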
type stringReader struct{ str string }

func (r *stringReader) Read(p []byte) (int, error) {
	if len(r.str) == 0 {
		return 0, io.EOF
	}
	n := copy(p, r.str)
	r.str = r.str[n:]
	return n, nil
}