TradingAgents/tradingagents/validation/semantic_fact_checker.py

596 lines
20 KiB
Python

"""
Production Semantic Fact Checker with NLI
Features:
- DeBERTa-based entailment checking
- Targeted validation (final arguments only, not full conversation)
- Hash-based caching to prevent redundant checks
- Catches semantic contradictions ("fell" vs "rose")
"""
import hashlib
import json
import math
import re
from dataclasses import dataclass, replace
from enum import Enum
from typing import Any, Dict, List, Optional
class EntailmentLabel(Enum):
    """Three-way NLI verdict for a (premise, hypothesis) pair."""

    ENTAILMENT = "entailment"  # ground truth supports the claim
    CONTRADICTION = "contradiction"  # ground truth refutes the claim
    NEUTRAL = "neutral"  # ground truth neither supports nor refutes the claim


@dataclass
class FactCheckResult:
    """Outcome of validating one claim against ground truth."""

    # False only for claims judged to contradict the ground truth.
    valid: bool
    label: EntailmentLabel
    # Confidence in `label`, 0.0-1.0 (0.0 when no ground truth was available).
    confidence: float
    # Human-readable justification for the verdict.
    evidence: str
    # True when the result was served from the per-day cache.
    cached: bool = False


class SemanticFactChecker:
    """
    Validate claims using NLI (Natural Language Inference).

    CRITICAL OPTIMIZATIONS:
    1. Targeted validation: only check final arguments, not the full conversation.
    2. Caching: results are cached per (argument, trading day) with FIFO eviction.
    3. Layered checking: a deterministic numeric hard-check runs before the NLI
       model, the LLM fallback, or the keyword fallback is consulted.
    """

    # --- Numeric hard-check patterns (signs are not captured; direction is
    # --- judged by the later layers) -----------------------------------------
    _PCT_RE = re.compile(r"(\d+(?:\.\d+)?)\s*%")  # "500%", "5.5 %"
    _DOLLAR_RE = re.compile(r"\$\s*(\d+(?:\.\d+)?)")  # "$500", "$ 8.50"
    _PLAIN_NUM_RE = re.compile(r"\b(\d+(?:\.\d+)?)\b")  # "500", "8"
    _ANY_NUM_RE = re.compile(r"\d+(?:\.\d+)?")

    # --- Claim classification ------------------------------------------------
    _REVENUE_WORDS = ("revenue", "earnings", "sales", "income")
    _EARNINGS_WORDS = ("earnings", "income")
    _TOPLINE_WORDS = ("revenue", "sales")
    _PRICE_WORDS = ("price", "stock", "share")
    # Indicator acronyms must be whole words (an attached period such as
    # "SMA50" is allowed) so "small", "demand" or "diversification" do not
    # count as technical claims.
    _TECHNICAL_RE = re.compile(
        r"\b(?:rsi|macd|sma|ema)\d*\b|\bbollinger", re.IGNORECASE
    )
    _RSI_MENTION_RE = re.compile(r"\brsi\b", re.IGNORECASE)
    # The value that follows an "RSI" mention, skipping an optional period
    # such as "RSI(14)" so "RSI(14) at 72" yields 72.
    _RSI_VALUE_RE = re.compile(
        r"\brsi\b(?:\s*\(\s*\d+\s*\))?\D*?(\d+(?:\.\d+)?)", re.IGNORECASE
    )
    _RSI_TOLERANCE = 2.0  # points of RSI treated as a match

    # --- Keyword fallback: whole-word direction cues -------------------------
    _INCREASE_RE = re.compile(
        r"\b(?:increas(?:e|es|ed|ing)|grew|grow(?:s|ing)?|rose|ris(?:e|es|ing)"
        r"|up|gain(?:s|ed|ing)?|higher)\b",
        re.IGNORECASE,
    )
    _DECREASE_RE = re.compile(
        r"\b(?:decreas(?:e|es|ed|ing)|fell|fall(?:s|ing)?|drop(?:s|ped|ping)?"
        r"|down|loss(?:es)?|lower(?:s|ed)?|declin(?:e|es|ed|ing))\b",
        re.IGNORECASE,
    )

    # --- LLM fallback --------------------------------------------------------
    # Strips a surrounding ```json ... ``` fence that chat models often add.
    _JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*|\s*```$", re.IGNORECASE)
    _LLM_LABELS = {
        "entailment": EntailmentLabel.ENTAILMENT,
        "contradiction": EntailmentLabel.CONTRADICTION,
        "neutral": EntailmentLabel.NEUTRAL,
    }

    def __init__(
        self,
        model_name: str = "microsoft/deberta-v3-small",
        use_local_model: bool = True,
        cache_size: int = 10000,
    ):
        """
        Initialize fact checker.

        Args:
            model_name: HuggingFace model loaded into a "text-classification"
                pipeline. NOTE(review): the default appears to be a base
                checkpoint without an NLI head; generic labels such as
                "LABEL_0" map to NEUTRAL in `_check_entailment_nli`. Consider
                an NLI fine-tuned checkpoint — confirm before changing.
            use_local_model: Try to load the local model; if loading fails,
                fall back to LLM / keyword validation.
            cache_size: Maximum cache entries (<= 0 disables caching).
        """
        self.use_local_model = use_local_model
        self.nli_pipeline = None
        self.llm = None
        # Cache: {claim_hash: FactCheckResult}; dict insertion order gives FIFO.
        self.cache: Dict[str, FactCheckResult] = {}
        self.cache_size = cache_size
        # Lifetime counters behind get_cache_stats()["hit_rate"].
        self._cache_hits = 0
        self._cache_misses = 0

        if use_local_model:
            try:
                from transformers import pipeline
                import torch

                self.nli_pipeline = pipeline(
                    "text-classification",
                    model=model_name,
                    device=0 if torch.cuda.is_available() else -1,
                )
                print(f"✅ Loaded NLI model: {model_name}")
            except Exception as e:
                # Best effort: any load failure (missing package, no network,
                # bad model id) degrades to the non-model validation paths.
                print(f"⚠️ Could not load NLI model: {e}")
                print(" Falling back to LLM-based validation")
                self.use_local_model = False

    def set_llm(self, llm):
        """Set the LLM used when no local NLI model is available.

        The object must expose ``invoke(prompt)``; see `_check_entailment_llm`.
        """
        self.llm = llm

    def validate_arguments(
        self,
        arguments: List[str],
        ground_truth: Dict[str, Any],
        trading_date: str,
    ) -> Dict[str, FactCheckResult]:
        """
        Validate a list of arguments against ground truth.

        TARGETED VALIDATION: only validates final arguments, not the full
        conversation.

        Args:
            arguments: Claims to validate (from JSON "key_arguments").
            ground_truth: Structured ground truth data.
            trading_date: Date used to scope the cache key. The key does NOT
                include `ground_truth`, so a cached verdict is reused for the
                same argument on the same date even if the data differs.

        Returns:
            Dict mapping each argument to its FactCheckResult. Results served
            from cache are copies with ``cached=True``.
        """
        results: Dict[str, FactCheckResult] = {}
        for argument in arguments:
            cache_key = self._get_cache_key(argument, trading_date)
            cached_result = self.cache.get(cache_key)
            if cached_result is not None:
                self._cache_hits += 1
                # Hand out a flagged copy: mutating the stored object would also
                # flip `cached` on results already returned to earlier callers.
                results[argument] = replace(cached_result, cached=True)
                continue

            self._cache_misses += 1
            result = self._validate_single_argument(argument, ground_truth)
            self._add_to_cache(cache_key, result)
            results[argument] = result
        return results

    def _validate_single_argument(
        self,
        argument: str,
        ground_truth: Dict[str, Any],
    ) -> FactCheckResult:
        """Route one argument to the validator matching its classified type.

        Qualitative claims cannot be verified and are reported as valid/NEUTRAL.
        """
        validators = {
            "revenue": self._validate_revenue_claim,
            "price": self._validate_price_claim,
            "technical": self._validate_technical_claim,
        }
        validator = validators.get(self._classify_argument(argument))
        if validator is None:
            return FactCheckResult(
                valid=True,  # Assume valid if can't verify
                label=EntailmentLabel.NEUTRAL,
                confidence=0.5,
                evidence="Qualitative claim - cannot verify",
            )
        return validator(argument, ground_truth)

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for absent ground truth: None or a float NaN."""
        return value is None or (isinstance(value, float) and math.isnan(value))

    @staticmethod
    def _describe_change(
        subject: str, change: float, suffix: str, flat_text: str
    ) -> str:
        """Render a fractional change (0.05 == 5%) as a premise sentence."""
        if change > 0:
            return f"{subject} increased by {abs(change):.1%}{suffix}."
        if change < 0:
            return f"{subject} decreased by {abs(change):.1%}{suffix}."
        return f"{subject} {flat_text}{suffix}."

    def _validate_revenue_claim(
        self,
        claim: str,
        ground_truth: Dict[str, Any],
    ) -> FactCheckResult:
        """
        Validate a revenue/earnings claim against growth figures.

        Claims that mention earnings/income (but not revenue/sales) are checked
        against ``earnings_growth`` when present; otherwise, as before, against
        ``revenue_growth_yoy``. Growth values are fractions (0.05 == 5%).

        Example:
            Claim: "Revenue fell 5%"
            Truth: revenue_growth_yoy = 0.05 (grew 5%)
            Result: CONTRADICTION
        """
        claim_lower = claim.lower()
        is_earnings_claim = any(w in claim_lower for w in self._EARNINGS_WORDS) and not any(
            w in claim_lower for w in self._TOPLINE_WORDS
        )
        earnings_growth = ground_truth.get("earnings_growth")
        if is_earnings_claim and not self._is_missing(earnings_growth):
            premise = self._describe_change(
                "Earnings", earnings_growth, "", "remained flat"
            )
            return self._check_entailment(premise, claim)

        revenue_growth = ground_truth.get("revenue_growth_yoy")
        if self._is_missing(revenue_growth):
            return FactCheckResult(
                valid=True,
                label=EntailmentLabel.NEUTRAL,
                confidence=0.0,
                evidence="No revenue data available",
            )
        premise = self._describe_change(
            "Revenue", revenue_growth, " year-over-year", "remained flat"
        )
        return self._check_entailment(premise, claim)

    def _validate_price_claim(
        self,
        claim: str,
        ground_truth: Dict[str, Any],
    ) -> FactCheckResult:
        """Validate a price-movement claim against ``price_change_pct`` (a fraction)."""
        price_change = ground_truth.get("price_change_pct")
        if self._is_missing(price_change):
            return FactCheckResult(
                valid=True,
                label=EntailmentLabel.NEUTRAL,
                confidence=0.0,
                evidence="No price data available",
            )
        premise = self._describe_change("Price", price_change, "", "remained unchanged")
        return self._check_entailment(premise, claim)

    def _validate_technical_claim(
        self,
        claim: str,
        ground_truth: Dict[str, Any],
    ) -> FactCheckResult:
        """Validate a technical-indicator claim with a simple numeric comparison.

        Only RSI is checked: the number following "RSI" in the claim (or the
        first number, if none follows) is compared with
        ``ground_truth["indicators"]["RSI"]`` within `_RSI_TOLERANCE` points.
        """
        claim_numbers = self._ANY_NUM_RE.findall(claim)
        if not claim_numbers:
            return FactCheckResult(
                valid=True,
                label=EntailmentLabel.NEUTRAL,
                confidence=0.5,
                evidence="No numbers in claim",
            )

        indicators = ground_truth.get("indicators") or {}
        truth_val = indicators.get("RSI") if isinstance(indicators, dict) else None
        if truth_val is not None and self._RSI_MENTION_RE.search(claim):
            value_match = self._RSI_VALUE_RE.search(claim)
            claim_val = float(value_match.group(1) if value_match else claim_numbers[0])
            if abs(claim_val - truth_val) < self._RSI_TOLERANCE:
                return FactCheckResult(
                    valid=True,
                    label=EntailmentLabel.ENTAILMENT,
                    confidence=0.9,
                    evidence=f"RSI values match: {claim_val} ≈ {truth_val}",
                )
            return FactCheckResult(
                valid=False,
                label=EntailmentLabel.CONTRADICTION,
                confidence=0.8,
                evidence=f"RSI mismatch: claimed {claim_val}, actual {truth_val}",
            )

        return FactCheckResult(
            valid=True,
            label=EntailmentLabel.NEUTRAL,
            confidence=0.5,
            evidence="Cannot verify technical claim",
        )

    def _check_entailment(
        self,
        premise: str,
        hypothesis: str,
    ) -> FactCheckResult:
        """
        Check if premise entails hypothesis using HYBRID VALIDATION.

        LAYER 1: Numeric hard-check (sanity layer) — see
            `_check_numeric_divergence`; a mismatch is returned immediately.
        LAYER 2: Context layer — DeBERTa NLI pipeline if loaded, else the LLM
            if one was set, else keyword direction matching.

        Args:
            premise: Ground truth statement.
            hypothesis: Claim to verify.

        Returns:
            FactCheckResult
        """
        numeric_check = self._check_numeric_divergence(premise, hypothesis)
        if numeric_check is not None:
            return numeric_check

        if self.use_local_model and self.nli_pipeline:
            return self._check_entailment_nli(premise, hypothesis)
        if self.llm:
            return self._check_entailment_llm(premise, hypothesis)
        return self._check_entailment_fallback(premise, hypothesis)

    @staticmethod
    def _relative_divergence(claim_val: float, truth_val: float) -> float:
        """Relative difference to the truth; absolute difference when truth is 0."""
        if truth_val > 0:
            return abs(claim_val - truth_val) / truth_val
        return abs(claim_val - truth_val)

    def _check_numeric_divergence(
        self,
        premise: str,
        hypothesis: str,
        tolerance: float = 0.10,
    ) -> Optional[FactCheckResult]:
        """
        LAYER 1: Numeric hard-check (the "sanity" layer).

        Compares the FIRST percentage, then the FIRST dollar amount, then (only
        if the premise has neither) the first plain number >= 10 in premise and
        hypothesis. If the relative divergence exceeds `tolerance`, a
        CONTRADICTION is returned. Do not let a model decide if 500 equals 8.

        Args:
            premise: Ground truth statement.
            hypothesis: Claim to verify.
            tolerance: Maximum allowed relative divergence (default 10%).

        Returns:
            FactCheckResult if a numeric contradiction is found, None otherwise.
        """
        premise_pcts = self._PCT_RE.findall(premise)
        hyp_pcts = self._PCT_RE.findall(hypothesis)
        premise_dollars = self._DOLLAR_RE.findall(premise)
        hyp_dollars = self._DOLLAR_RE.findall(hypothesis)

        # Percentages first: the most common figure in financial claims.
        if premise_pcts and hyp_pcts:
            truth_val, claim_val = float(premise_pcts[0]), float(hyp_pcts[0])
            divergence = self._relative_divergence(claim_val, truth_val)
            if divergence > tolerance:
                return FactCheckResult(
                    valid=False,
                    label=EntailmentLabel.CONTRADICTION,
                    confidence=1.0,  # Hard math, 100% confident
                    evidence=f"Numeric mismatch: Claim {claim_val}% vs Truth {truth_val}% (divergence: {divergence:.1%})",
                )

        if premise_dollars and hyp_dollars:
            truth_val, claim_val = float(premise_dollars[0]), float(hyp_dollars[0])
            divergence = self._relative_divergence(claim_val, truth_val)
            if divergence > tolerance:
                return FactCheckResult(
                    valid=False,
                    label=EntailmentLabel.CONTRADICTION,
                    confidence=1.0,
                    evidence=f"Numeric mismatch: Claim ${claim_val} vs Truth ${truth_val} (divergence: {divergence:.1%})",
                )

        # Plain numbers are less reliable: only used when the premise has no
        # % or $ figures, and only for values >= 10 to avoid false positives.
        if not premise_pcts and not premise_dollars:
            premise_nums = self._PLAIN_NUM_RE.findall(premise)
            hyp_nums = self._PLAIN_NUM_RE.findall(hypothesis)
            if premise_nums and hyp_nums:
                truth_val, claim_val = float(premise_nums[0]), float(hyp_nums[0])
                if truth_val >= 10:
                    divergence = abs(claim_val - truth_val) / truth_val
                    if divergence > tolerance:
                        return FactCheckResult(
                            valid=False,
                            label=EntailmentLabel.CONTRADICTION,
                            confidence=0.9,  # Slightly less confident for plain numbers
                            evidence=f"Numeric mismatch: Claim {claim_val} vs Truth {truth_val} (divergence: {divergence:.1%})",
                        )

        return None

    def _check_entailment_nli(
        self,
        premise: str,
        hypothesis: str,
    ) -> FactCheckResult:
        """Use the loaded NLI pipeline for entailment checking."""
        # Pass a real sentence pair so the tokenizer inserts the model's own
        # separator and pair encoding (a literal "[SEP]" string does not).
        output = self.nli_pipeline({"text": premise, "text_pair": hypothesis})
        # Depending on input type / transformers version the pipeline returns a
        # dict or a one-element list of dicts.
        if isinstance(output, list):
            output = output[0]
        label_str = str(output["label"]).lower()
        confidence = float(output["score"])

        # Check contradiction and "not_entailment" before "entail" so that
        # two-class "not_entailment" labels are not read as ENTAILMENT.
        if "contradict" in label_str:
            label = EntailmentLabel.CONTRADICTION
            valid = False
            evidence = f"Claim contradicts ground truth: {premise}"
        elif "entail" in label_str and not any(
            neg in label_str for neg in ("not_entail", "non_entail", "not entail")
        ):
            label = EntailmentLabel.ENTAILMENT
            valid = True
            evidence = f"Claim entailed by ground truth: {premise}"
        else:
            label = EntailmentLabel.NEUTRAL
            valid = True  # Neutral = can't disprove
            evidence = f"Claim neither entailed nor contradicted: {premise}"

        return FactCheckResult(
            valid=valid,
            label=label,
            confidence=confidence,
            evidence=evidence,
        )

    @classmethod
    def _parse_llm_json(cls, text: Any) -> Any:
        """Parse an LLM reply as JSON, tolerating a ```json code fence.

        Raises:
            ValueError: (json.JSONDecodeError) if the text is not valid JSON.
        """
        cleaned = cls._JSON_FENCE_RE.sub("", str(text).strip())
        return json.loads(cleaned)

    def _check_entailment_llm(
        self,
        premise: str,
        hypothesis: str,
    ) -> FactCheckResult:
        """Fallback: use the LLM for entailment checking.

        Malformed replies (not JSON, missing "entailment"/"confidence", wrong
        types) fall back to `_check_entailment_fallback`. Errors raised by
        ``llm.invoke`` itself propagate to the caller.
        """
        prompt = f"""Determine if the Hypothesis is supported by the Premise.
Premise (Ground Truth): {premise}
Hypothesis (Claim): {hypothesis}
Respond in JSON:
{{
"entailment": "entailment" | "contradiction" | "neutral",
"confidence": 0.0-1.0,
"reasoning": "brief explanation"
}}"""
        response = self.llm.invoke(prompt)
        # Chat models return a message with `.content`; plain LLMs may return str.
        text = getattr(response, "content", response)
        try:
            result = self._parse_llm_json(text)
            label = self._LLM_LABELS.get(
                str(result["entailment"]).strip().lower(), EntailmentLabel.NEUTRAL
            )
            confidence = min(max(float(result["confidence"]), 0.0), 1.0)
            evidence = str(result.get("reasoning", ""))
        except (ValueError, KeyError, TypeError, AttributeError):
            return self._check_entailment_fallback(premise, hypothesis)

        return FactCheckResult(
            valid=label != EntailmentLabel.CONTRADICTION,
            label=label,
            confidence=confidence,
            evidence=evidence,
        )

    @classmethod
    def _direction(cls, text: str) -> Optional[str]:
        """Return "increase", "decrease" or None from whole-word direction cues.

        Increase cues win when both kinds appear.
        """
        if cls._INCREASE_RE.search(text):
            return "increase"
        if cls._DECREASE_RE.search(text):
            return "decrease"
        return None

    def _check_entailment_fallback(
        self,
        premise: str,
        hypothesis: str,
    ) -> FactCheckResult:
        """Last resort: compare movement direction via whole-word keyword matching."""
        premise_dir = self._direction(premise)
        hyp_dir = self._direction(hypothesis)

        if premise_dir and hyp_dir:
            if premise_dir == hyp_dir:
                return FactCheckResult(
                    valid=True,
                    label=EntailmentLabel.ENTAILMENT,
                    confidence=0.7,
                    evidence=f"Directions match: both {premise_dir}",
                )
            return FactCheckResult(
                valid=False,
                label=EntailmentLabel.CONTRADICTION,
                confidence=0.8,
                evidence=f"Direction mismatch: {premise_dir} vs {hyp_dir}",
            )

        return FactCheckResult(
            valid=True,
            label=EntailmentLabel.NEUTRAL,
            confidence=0.5,
            evidence="Cannot determine entailment",
        )

    def _classify_argument(self, argument: str) -> str:
        """Classify an argument as "revenue", "price", "technical" or "qualitative".

        Checks run in that order, so a claim mentioning revenue and RSI is
        treated as a revenue claim.
        """
        arg_lower = argument.lower()
        if any(w in arg_lower for w in self._REVENUE_WORDS):
            return "revenue"
        if any(w in arg_lower for w in self._PRICE_WORDS):
            return "price"
        if self._TECHNICAL_RE.search(argument):
            return "technical"
        return "qualitative"

    def _get_cache_key(self, argument: str, trading_date: str) -> str:
        """Generate a cache key from argument and date (MD5, not for security)."""
        hash_input = f"{argument}_{trading_date}"
        return hashlib.md5(hash_input.encode()).hexdigest()

    def _add_to_cache(self, key: str, result: FactCheckResult):
        """Add a result, evicting the oldest entries (FIFO) to respect `cache_size`."""
        if self.cache_size <= 0:
            return  # caching disabled
        while len(self.cache) >= self.cache_size:
            del self.cache[next(iter(self.cache))]
        self.cache[key] = result

    def get_cache_stats(self) -> Dict[str, float]:
        """Return cache size, capacity, lifetime hits/misses and hit rate."""
        return {
            "size": len(self.cache),
            "max_size": self.cache_size,
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "hit_rate": self._calculate_hit_rate(),
        }

    def _calculate_hit_rate(self) -> float:
        """Fraction of lookups served from cache since construction (0.0 if none)."""
        total = self._cache_hits + self._cache_misses
        return self._cache_hits / total if total else 0.0

    def clear_cache(self):
        """Clear cached results (e.g., at end of trading day).

        Hit/miss counters are kept, so the hit rate stays a lifetime figure.
        """
        self.cache.clear()
# Example usage: run this module directly to see the keyword fallback
# flag a directional contradiction.
if __name__ == "__main__":
    # No local model, no LLM -> exercises the keyword-matching fallback.
    demo_checker = SemanticFactChecker(use_local_model=False)  # Use fallback for demo

    # Test: Contradictory claim
    demo_claims = [
        "Revenue fell by 5% last quarter",
        "Strong earnings growth of 10%",
    ]
    demo_truth = {
        "revenue_growth_yoy": 0.05,  # Actually grew 5%
        "earnings_growth": 0.10,
    }

    verdicts = demo_checker.validate_arguments(demo_claims, demo_truth, "2024-01-15")
    for claim_text, verdict in verdicts.items():
        report_lines = (
            f"\nArgument: {claim_text}",
            f"Valid: {verdict.valid}",
            f"Label: {verdict.label.value}",
            f"Evidence: {verdict.evidence}",
            f"Cached: {verdict.cached}",
        )
        for report_line in report_lines:
            print(report_line)