technical-screen-2025-10-22/modules/rag_agent.py

287 lines
10 KiB
Python

#!/usr/bin/env python3
import re
import json
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from .client import get_openrouter_client
@dataclass
class MarketCapClaim:
    """Represents a market cap claim found in slide text.

    One instance is produced per regex hit by
    ``MarketCapRAGAgent.extract_market_cap_claims``.
    """

    # 'slide_number' value of the slide dict the claim came from
    # (the extractor falls back to 0 when the key is missing).
    slide_number: int
    # Best-effort company name guessed from the first lines of the slide;
    # "Unknown Company" when nothing suitable was found.
    company_name: str
    # The captured figure without the leading "$", e.g. "2.5B".
    claimed_market_cap: str
    # The full text span the pattern matched, e.g. "market cap: $2.5B".
    raw_text: str
    # Heuristic score in [0.5, 1.0] derived from the words around the match.
    confidence: float
@dataclass
class ValidationResult:
    """Represents the validation result for a market cap claim.

    Built by ``MarketCapRAGAgent.validate_claim_with_rag`` both on success
    and on failure; in the failure case ``validation_source`` is "Error" and
    ``discrepancy`` / ``rag_response`` carry the exception text.
    """

    # The claim that was checked.
    claim: MarketCapClaim
    # Figure extracted from the model's answer, or None when none was found.
    validated_market_cap: Optional[str]
    # Where the figure came from ("RAG Search" by default, "Error" on failure).
    validation_source: str
    # Confidence attached to the validation (0.0 on failure).
    confidence_score: float
    # True when claimed and validated figures were judged close enough.
    is_accurate: bool
    # Human-readable mismatch or error description; None when there is none.
    discrepancy: Optional[str]
    # Search-style query string recorded alongside the result.
    rag_search_query: str
    # Raw text returned by the model (or "Error: ..." on failure).
    rag_response: str
class MarketCapRAGAgent:
    """
    RAG Agent for validating market cap claims from pitch deck slides
    using OpenRouter's web search capabilities.

    Workflow: ``extract_market_cap_claims`` finds figures that appear next to
    market-cap / valuation wording in slide text, ``validate_claim_with_rag``
    asks the chat model about each one, and ``_parse_rag_response`` turns the
    free-text answer into a figure that can be compared with the claim.
    """

    # A number: digit groups separated by "," or "." (never ending in
    # punctuation), or a bare fraction such as ".5".
    _NUMBER = r'(?:[0-9]+(?:[,.][0-9]+)*|\.[0-9]+)'
    # An optional scale after the number.  Scale words may be separated from
    # the number by whitespace ("2.5 billion"); single letters must follow
    # immediately and end at a word boundary so that "5th" or "5kg" are not
    # read as trillions / thousands.
    _SCALE = r'(?:\s*(?:trillion|billion|million|thousand|tn|bn|mn|mm)\b|[TBMK]\b)'
    # Capturing group for "number plus optional scale", shared by all patterns.
    _AMOUNT = '(' + _NUMBER + _SCALE + '?)'

    # Patterns used on slide text; group 1 is always the claimed figure.
    _CLAIM_PATTERNS = (
        r'market\s+cap(?:italization)?\s*:?\s*\$?' + _AMOUNT,
        r'valuation\s*:?\s*\$?' + _AMOUNT,
        r'worth\s*:?\s*\$?' + _AMOUNT,
        r'valued\s+at\s*:?\s*\$?' + _AMOUNT,
        r'\$' + _AMOUNT + r'\s+(?:market\s+cap|valuation)',
        r'(?:market\s+cap|valuation)\s+of\s+\$?' + _AMOUNT,
    )

    # Patterns used on the model's answer, tried in order; first hit wins.
    _RESPONSE_CAP_PATTERNS = (
        r'\$' + _AMOUNT,
        '(' + _NUMBER + r'\s+(?:trillion|billion|million)\b)',
        r'market\s+cap(?:italization)?\s*:?\s*\$?' + _AMOUNT,
    )

    # Scale suffixes understood by ``_normalize_value`` (upper-cased input).
    # Longest first so that "MM" is recognised before the single letter "M".
    _SCALE_SUFFIXES = (
        ('TRILLION', 1_000_000_000_000),
        ('THOUSAND', 1_000),
        ('BILLION', 1_000_000_000),
        ('MILLION', 1_000_000),
        ('TN', 1_000_000_000_000),
        ('BN', 1_000_000_000),
        ('MN', 1_000_000),
        ('MM', 1_000_000),
        ('T', 1_000_000_000_000),
        ('B', 1_000_000_000),
        ('M', 1_000_000),
        ('K', 1_000),
    )

    # Lines containing one of these words (not embedded in a longer alphabetic
    # word) are treated as slide furniture rather than a company name.
    _HEADER_WORD_RE = re.compile(
        r'(?<![A-Za-z])(?:slides?|pages?|agenda|overview)(?![A-Za-z])',
        re.IGNORECASE,
    )

    # Picks the text following "source:" / "according to" up to end of line.
    _SOURCE_RE = re.compile(r'(?:source:|according to)\s*([^\n]+)', re.IGNORECASE)

    def __init__(self, api_key: Optional[str] = None):
        """Create the agent and its chat client.

        Args:
            api_key: Accepted for interface compatibility. It is currently
                not used: the client is obtained from
                ``get_openrouter_client()`` with no arguments.
        """
        self.client = get_openrouter_client()
        # Kept as a per-instance list of pattern strings so callers can
        # inspect or extend it; every pattern must expose the figure as group 1.
        self.market_cap_patterns = list(self._CLAIM_PATTERNS)

    def extract_market_cap_claims(self, slide_texts: List[Dict[str, Any]]) -> List[MarketCapClaim]:
        """
        Extract market cap claims from slide text exports.

        Args:
            slide_texts: List of slide data with 'slide_number' and 'text' keys.
                Slides whose 'text' is missing or empty are skipped.

        Returns:
            List of MarketCapClaim objects, in slide order. A figure matched
            by several patterns on the same slide is reported only once.
        """
        claims: List[MarketCapClaim] = []
        for slide_data in slide_texts:
            text = slide_data.get('text', '')
            if not text:
                continue
            slide_number = slide_data.get('slide_number', 0)
            # Company name is guessed once per slide (title / first lines).
            company_name = self._extract_company_name(text)
            # Character spans of figures already reported for this slide, so
            # overlapping patterns do not trigger duplicate validations.
            seen_spans = set()
            for pattern in self.market_cap_patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE | re.MULTILINE):
                    figure_span = match.span(1)
                    if figure_span in seen_spans:
                        continue
                    seen_spans.add(figure_span)
                    claims.append(
                        MarketCapClaim(
                            slide_number=slide_number,
                            company_name=company_name,
                            claimed_market_cap=match.group(1),
                            raw_text=match.group(0),
                            confidence=self._calculate_confidence(
                                text, match.start(), match.end()
                            ),
                        )
                    )
        return claims

    def _extract_company_name(self, text: str) -> str:
        """Guess the company name from the first five lines of slide text.

        Returns the first line that is 3-99 characters long (after stripping)
        and does not look like a slide header ("Slide 3", "Agenda", ...).
        Falls back to "Unknown Company".
        """
        for raw_line in text.split('\n')[:5]:
            line = raw_line.strip()
            if not 2 < len(line) < 100:
                continue
            # Whole-word test: "Landslide Analytics" is a name, "Slide 1" is not.
            if self._HEADER_WORD_RE.search(line):
                continue
            return line
        return "Unknown Company"

    def _calculate_confidence(self, text: str, start: int, end: int) -> float:
        """Calculate a heuristic confidence score for a market cap claim.

        Looks at the text from 50 characters before ``start`` to 50 characters
        after ``end`` and adds to a base of 0.5 when recency words, scale
        words, or explicit market-cap wording are present. Capped at 1.0.
        """
        confidence = 0.5  # Base confidence
        window = text[max(0, start - 50):min(len(text), end + 50)].lower()
        # Recency hints suggest the figure is meant as a current value.
        if any(hint in window for hint in ('current', 'latest', 'as of', '2024', '2025')):
            confidence += 0.2
        if any(scale in window for scale in ('billion', 'million', 'trillion')):
            confidence += 0.1
        if 'market cap' in window or 'valuation' in window:
            confidence += 0.2
        return min(confidence, 1.0)

    def validate_claim_with_rag(self, claim: MarketCapClaim) -> ValidationResult:
        """
        Validate a market cap claim by asking the chat model about it.

        Args:
            claim: MarketCapClaim to validate.

        Returns:
            ValidationResult with validation details. Any exception raised
            while calling the model or parsing its answer is converted into a
            result with ``validation_source="Error"`` rather than propagated.
        """
        # Recorded on the result for traceability; it is not sent to the model.
        search_query = f"{claim.company_name} current market cap valuation 2024 2025"
        try:
            prompt = (
                "\n"
                f"Please search for the current market cap or valuation of {claim.company_name}.\n"
                f"The company claims their market cap is ${claim.claimed_market_cap}.\n"
                "Please provide:\n"
                "1. The current market cap/valuation if found\n"
                "2. The source of this information\n"
                "3. Whether the claimed value appears accurate\n"
                "4. Any significant discrepancies\n"
                "Focus on recent data from 2024-2025.\n"
            )
            # NOTE(review): the model slug carries no ":online" suffix and no
            # web plugin is passed here, so web search may not actually be
            # enabled for this call -- confirm against the OpenRouter docs
            # and the configuration done in get_openrouter_client().
            response = self.client.chat.completions.create(
                model="mistralai/mistral-small",
                messages=[{"role": "user", "content": prompt}],
                max_tokens=800,
            )
            rag_response = response.choices[0].message.content.strip()
            # Turn the free-text answer into structured validation details.
            validation_details = self._parse_rag_response(rag_response, claim)
            return ValidationResult(
                claim=claim,
                validated_market_cap=validation_details.get('validated_cap'),
                validation_source=validation_details.get('source', 'RAG Search'),
                confidence_score=validation_details.get('confidence', 0.5),
                is_accurate=validation_details.get('is_accurate', False),
                discrepancy=validation_details.get('discrepancy'),
                rag_search_query=search_query,
                rag_response=rag_response,
            )
        except Exception as e:
            # Deliberate best-effort boundary: one failed lookup must not
            # abort validation of the remaining claims.
            return ValidationResult(
                claim=claim,
                validated_market_cap=None,
                validation_source="Error",
                confidence_score=0.0,
                is_accurate=False,
                discrepancy=f"RAG search failed: {str(e)}",
                rag_search_query=search_query,
                rag_response=f"Error: {str(e)}",
            )

    def _parse_rag_response(self, response: str, claim: MarketCapClaim) -> Dict[str, Any]:
        """Parse the model's answer into validation details.

        Returns a dict with keys 'validated_cap', 'source', 'confidence',
        'is_accurate' and 'discrepancy'. The first figure found in the answer
        is taken as the validated value; it counts as accurate when the
        smaller of (claimed, validated) is more than 80% of the larger.
        """
        details: Dict[str, Any] = {
            'validated_cap': None,
            'source': 'RAG Search',
            'confidence': 0.5,
            'is_accurate': False,
            'discrepancy': None,
        }
        # Match case-insensitively on the original text so that scale
        # suffixes ("B", "billion") are captured together with the number.
        # NOTE(review): the first figure in the answer may simply echo the
        # claimed value; this heuristic does not distinguish the two.
        for pattern in self._RESPONSE_CAP_PATTERNS:
            found = re.search(pattern, response, re.IGNORECASE)
            if found:
                details['validated_cap'] = found.group(1)
                break
        # Determine accuracy
        if details['validated_cap']:
            claimed_normalized = self._normalize_value(claim.claimed_market_cap)
            validated_normalized = self._normalize_value(details['validated_cap'])
            # Zero / unparseable values are falsy, which also guards the division.
            if claimed_normalized and validated_normalized:
                # Allow for some variance (within 20%)
                ratio = (
                    min(claimed_normalized, validated_normalized)
                    / max(claimed_normalized, validated_normalized)
                )
                details['is_accurate'] = ratio > 0.8
                if not details['is_accurate']:
                    details['discrepancy'] = (
                        f"Claimed: ${claim.claimed_market_cap}, Found: ${details['validated_cap']}"
                    )
        # Extract source information, keeping the source's original casing.
        source_match = self._SOURCE_RE.search(response)
        if source_match:
            details['source'] = source_match.group(1).strip()
        return details

    def _normalize_value(self, value: str) -> Optional[float]:
        """Normalize a market cap figure to a plain number of dollars.

        Accepts optional "$", thousands separators, and a scale given either
        as a letter (K/M/B/T) or a word/abbreviation (thousand, million, mm,
        mn, billion, bn, trillion, tn), in any case.

        Returns:
            The value as a float, or None when ``value`` is empty or cannot
            be parsed.
        """
        if not value:
            return None
        text = value.replace(',', '').replace('$', '').strip().rstrip('.').upper()
        multiplier = 1
        for suffix, scale in self._SCALE_SUFFIXES:
            if text.endswith(suffix):
                multiplier = scale
                text = text[:-len(suffix)].strip()
                break
        try:
            return float(text) * multiplier
        except ValueError:
            return None

    def validate_all_claims(self, slide_texts: List[Dict[str, Any]]) -> List[ValidationResult]:
        """
        Extract and validate all market cap claims from slide texts.

        Claims are validated one after another; progress is printed to stdout.

        Args:
            slide_texts: List of slide data with 'slide_number' and 'text' keys.

        Returns:
            List of ValidationResult objects, one per extracted claim.
        """
        claims = self.extract_market_cap_claims(slide_texts)
        results: List[ValidationResult] = []
        print(f"Found {len(claims)} market cap claims to validate...")
        for i, claim in enumerate(claims, 1):
            print(f"  Validating claim {i}/{len(claims)}: {claim.company_name} - ${claim.claimed_market_cap}")
            results.append(self.validate_claim_with_rag(claim))
        return results