"""
Production Semantic Fact Checker with NLI

Features:
- DeBERTa-based entailment checking
- Targeted validation (final arguments only, not full conversation)
- Hash-based caching to prevent redundant checks
- Catches semantic contradictions ("fell" vs "rose")
"""
|
|
|
|
import hashlib
import json
import re
from dataclasses import dataclass, replace
from enum import Enum
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
class EntailmentLabel(Enum):
    """Verdicts produced by an NLI (natural language inference) check."""

    ENTAILMENT = "entailment"        # claim is supported by the premise
    CONTRADICTION = "contradiction"  # claim conflicts with the premise
    NEUTRAL = "neutral"              # premise neither supports nor refutes the claim
|
|
|
|
|
|
@dataclass
class FactCheckResult:
    """Outcome of checking a single claim against ground truth."""

    valid: bool                # whether the claim is accepted (not contradicted)
    label: EntailmentLabel     # NLI verdict for the claim
    confidence: float          # checker confidence in the verdict (0.0-1.0)
    evidence: str              # human-readable justification for the verdict
    cached: bool = False       # True when the result was served from the cache
|
|
|
|
|
|
class SemanticFactChecker:
    """
    Validate claims using NLI (Natural Language Inference).

    CRITICAL OPTIMIZATIONS:
    1. Targeted validation: Only check final arguments, not full conversation
    2. Caching: Hash claims and cache results per trading day
    3. Batch processing: Check multiple claims in one NLI call
    """

    def __init__(
        self,
        model_name: str = "microsoft/deberta-v3-small",
        use_local_model: bool = True,
        cache_size: int = 10000
    ):
        """
        Initialize fact checker.

        Args:
            model_name: HuggingFace NLI model
            use_local_model: Try to load local model, fallback to LLM
            cache_size: Maximum cache entries
        """
        self.use_local_model = use_local_model
        self.nli_pipeline = None
        self.llm = None

        # Cache: {claim_hash: FactCheckResult}
        self.cache = {}
        self.cache_size = cache_size

        # Try to load the local NLI model; on any failure fall back to
        # LLM-based validation (see _check_entailment).
        if use_local_model:
            try:
                from transformers import pipeline
                import torch

                self.nli_pipeline = pipeline(
                    "text-classification",
                    model=model_name,
                    device=0 if torch.cuda.is_available() else -1
                )
                print(f"✅ Loaded NLI model: {model_name}")
            except Exception as e:
                print(f"⚠️ Could not load NLI model: {e}")
                print(" Falling back to LLM-based validation")
                self.use_local_model = False

    def set_llm(self, llm):
        """Set LLM for fallback validation."""
        self.llm = llm

    def validate_arguments(
        self,
        arguments: List[str],
        ground_truth: Dict[str, Any],
        trading_date: str
    ) -> Dict[str, FactCheckResult]:
        """
        Validate a list of arguments against ground truth.

        TARGETED VALIDATION: Only validates final arguments, not full conversation.

        Args:
            arguments: List of claims to validate (from JSON "key_arguments")
            ground_truth: Structured ground truth data
            trading_date: Date for cache scoping

        Returns:
            Dict mapping argument to FactCheckResult
        """
        results = {}

        for argument in arguments:
            # Check cache first
            cache_key = self._get_cache_key(argument, trading_date)

            if cache_key in self.cache:
                # FIX: return a cached=True copy instead of mutating the stored
                # entry in place; mutation permanently flipped the stored
                # object's `cached` flag, making hit/miss indistinguishable.
                results[argument] = replace(self.cache[cache_key], cached=True)
                continue

            # Validate uncached argument
            result = self._validate_single_argument(argument, ground_truth)

            # Cache result
            self._add_to_cache(cache_key, result)
            results[argument] = result

        return results

    def _validate_single_argument(
        self,
        argument: str,
        ground_truth: Dict[str, Any]
    ) -> FactCheckResult:
        """
        Validate a single argument.

        Args:
            argument: Claim to validate
            ground_truth: Ground truth data

        Returns:
            FactCheckResult
        """
        # Route the claim to the matching validator based on keyword type.
        arg_type = self._classify_argument(argument)

        if arg_type == "revenue":
            return self._validate_revenue_claim(argument, ground_truth)
        elif arg_type == "price":
            return self._validate_price_claim(argument, ground_truth)
        elif arg_type == "technical":
            return self._validate_technical_claim(argument, ground_truth)
        else:
            # Cannot validate qualitative claims
            return FactCheckResult(
                valid=True,  # Assume valid if can't verify
                label=EntailmentLabel.NEUTRAL,
                confidence=0.5,
                evidence="Qualitative claim - cannot verify"
            )

    def _validate_revenue_claim(
        self,
        claim: str,
        ground_truth: Dict[str, Any]
    ) -> FactCheckResult:
        """
        Validate revenue-related claim using NLI.

        Example:
            Claim: "Revenue fell 5%"
            Truth: revenue_growth_yoy = 0.05 (grew 5%)
            Result: CONTRADICTION
        """
        # Extract ground truth
        revenue_growth = ground_truth.get("revenue_growth_yoy")
        if revenue_growth is None:
            return FactCheckResult(
                valid=True,
                label=EntailmentLabel.NEUTRAL,
                confidence=0.0,
                evidence="No revenue data available"
            )

        # Construct premise from ground truth (revenue_growth is a fraction,
        # e.g. 0.05 -> "5.0%" via the :.1% format spec).
        if revenue_growth > 0:
            premise = f"Revenue increased by {abs(revenue_growth):.1%} year-over-year."
        elif revenue_growth < 0:
            premise = f"Revenue decreased by {abs(revenue_growth):.1%} year-over-year."
        else:
            premise = "Revenue remained flat year-over-year."

        # Check entailment
        return self._check_entailment(premise, claim)

    def _validate_price_claim(
        self,
        claim: str,
        ground_truth: Dict[str, Any]
    ) -> FactCheckResult:
        """Validate price movement claim."""
        price_change = ground_truth.get("price_change_pct")
        if price_change is None:
            return FactCheckResult(
                valid=True,
                label=EntailmentLabel.NEUTRAL,
                confidence=0.0,
                evidence="No price data available"
            )

        # Construct premise
        if price_change > 0:
            premise = f"Price increased by {abs(price_change):.1%}."
        elif price_change < 0:
            premise = f"Price decreased by {abs(price_change):.1%}."
        else:
            premise = "Price remained unchanged."

        return self._check_entailment(premise, claim)

    def _validate_technical_claim(
        self,
        claim: str,
        ground_truth: Dict[str, Any]
    ) -> FactCheckResult:
        """Validate technical indicator claim (simple numeric check)."""
        # For technical indicators, use simple numeric comparison.
        # Extract number from claim (module-level `re` import; the previous
        # redundant function-local import was removed).
        claim_numbers = re.findall(r'\d+(?:\.\d+)?', claim)

        if not claim_numbers:
            return FactCheckResult(
                valid=True,
                label=EntailmentLabel.NEUTRAL,
                confidence=0.5,
                evidence="No numbers in claim"
            )

        # Check if RSI/MACD values match ground truth
        indicators = ground_truth.get("indicators", {})

        # Simple heuristic: if claim mentions RSI and ground truth has RSI, compare
        if "rsi" in claim.lower() and "RSI" in indicators:
            claim_val = float(claim_numbers[0])
            truth_val = indicators["RSI"]

            if abs(claim_val - truth_val) < 2.0:  # Within 2 points
                return FactCheckResult(
                    valid=True,
                    label=EntailmentLabel.ENTAILMENT,
                    confidence=0.9,
                    evidence=f"RSI values match: {claim_val} ≈ {truth_val}"
                )
            else:
                return FactCheckResult(
                    valid=False,
                    label=EntailmentLabel.CONTRADICTION,
                    confidence=0.8,
                    evidence=f"RSI mismatch: claimed {claim_val}, actual {truth_val}"
                )

        return FactCheckResult(
            valid=True,
            label=EntailmentLabel.NEUTRAL,
            confidence=0.5,
            evidence="Cannot verify technical claim"
        )

    def _check_entailment(
        self,
        premise: str,
        hypothesis: str
    ) -> FactCheckResult:
        """
        Check if premise entails hypothesis using HYBRID VALIDATION.

        LAYER 1: Numeric Hard-Check (Sanity Layer)
        - Extract all % and $ values
        - If divergence > 10%, reject immediately
        - Do NOT let LLM decide if 500 equals 8

        LAYER 2: DeBERTa NLI Model (Context Layer)
        - Catches directional contradictions
        - Catches semantic shifts

        Args:
            premise: Ground truth statement
            hypothesis: Claim to verify

        Returns:
            FactCheckResult
        """
        # LAYER 1: NUMERIC HARD-CHECK
        numeric_check = self._check_numeric_divergence(premise, hypothesis)
        if numeric_check is not None:
            # Numeric contradiction found - reject immediately
            return numeric_check

        # LAYER 2: NLI MODEL (or fallback)
        if self.use_local_model and self.nli_pipeline:
            return self._check_entailment_nli(premise, hypothesis)
        elif self.llm:
            return self._check_entailment_llm(premise, hypothesis)
        else:
            return self._check_entailment_fallback(premise, hypothesis)

    def _check_numeric_divergence(
        self,
        premise: str,
        hypothesis: str,
        tolerance: float = 0.10
    ) -> Optional[FactCheckResult]:
        """
        LAYER 1: Numeric Hard-Check (The "Sanity" Layer)

        Extract all % and $ values from premise and hypothesis.
        If abs(claim - truth) / truth > tolerance, return CONTRADICTION immediately.

        DO NOT LET AN LLM DECIDE IF 500 EQUALS 8.

        Args:
            premise: Ground truth statement
            hypothesis: Claim to verify
            tolerance: Maximum allowed divergence (default 10%)

        Returns:
            FactCheckResult if numeric contradiction found, None otherwise
        """
        # Uses the module-level `re` import (redundant local import removed).

        # Extract percentages (e.g., "500%", "8%", "5.5%")
        premise_pcts = re.findall(r'(\d+(?:\.\d+)?)\s*%', premise)
        hyp_pcts = re.findall(r'(\d+(?:\.\d+)?)\s*%', hypothesis)

        # Extract dollar amounts (e.g., "$500", "$8.50")
        premise_dollars = re.findall(r'\$\s*(\d+(?:\.\d+)?)', premise)
        hyp_dollars = re.findall(r'\$\s*(\d+(?:\.\d+)?)', hypothesis)

        # Extract plain numbers (e.g., "500", "8")
        premise_nums = re.findall(r'\b(\d+(?:\.\d+)?)\b', premise)
        hyp_nums = re.findall(r'\b(\d+(?:\.\d+)?)\b', hypothesis)

        # Check percentages first (most common in financial claims).
        # Only the FIRST number on each side is compared.
        if premise_pcts and hyp_pcts:
            truth_val = float(premise_pcts[0])
            claim_val = float(hyp_pcts[0])

            # Calculate divergence (relative when truth > 0, absolute at 0,
            # since the regex can only produce non-negative values).
            if truth_val > 0:
                divergence = abs(claim_val - truth_val) / truth_val
            else:
                divergence = abs(claim_val - truth_val)

            if divergence > tolerance:
                return FactCheckResult(
                    valid=False,
                    label=EntailmentLabel.CONTRADICTION,
                    confidence=1.0,  # Hard math, 100% confident
                    evidence=f"Numeric mismatch: Claim {claim_val}% vs Truth {truth_val}% (divergence: {divergence:.1%})"
                )

        # Check dollar amounts
        if premise_dollars and hyp_dollars:
            truth_val = float(premise_dollars[0])
            claim_val = float(hyp_dollars[0])

            if truth_val > 0:
                divergence = abs(claim_val - truth_val) / truth_val
            else:
                divergence = abs(claim_val - truth_val)

            if divergence > tolerance:
                return FactCheckResult(
                    valid=False,
                    label=EntailmentLabel.CONTRADICTION,
                    confidence=1.0,
                    evidence=f"Numeric mismatch: Claim ${claim_val} vs Truth ${truth_val} (divergence: {divergence:.1%})"
                )

        # Check plain numbers (less reliable, only if no % or $)
        if not premise_pcts and not premise_dollars and premise_nums and hyp_nums:
            # Only check if numbers are large enough to be meaningful
            truth_val = float(premise_nums[0])
            claim_val = float(hyp_nums[0])

            if truth_val >= 10:  # Only check numbers >= 10 to avoid false positives
                if truth_val > 0:
                    divergence = abs(claim_val - truth_val) / truth_val
                else:
                    divergence = abs(claim_val - truth_val)

                if divergence > tolerance:
                    return FactCheckResult(
                        valid=False,
                        label=EntailmentLabel.CONTRADICTION,
                        confidence=0.9,  # Slightly less confident for plain numbers
                        evidence=f"Numeric mismatch: Claim {claim_val} vs Truth {truth_val} (divergence: {divergence:.1%})"
                    )

        # No numeric contradiction found
        return None

    def _check_entailment_nli(
        self,
        premise: str,
        hypothesis: str
    ) -> FactCheckResult:
        """Use DeBERTa NLI model for entailment checking."""
        # Format for NLI: premise [SEP] hypothesis
        input_text = f"{premise} [SEP] {hypothesis}"

        # Run NLI
        result = self.nli_pipeline(input_text)[0]

        label_str = result['label'].lower()
        confidence = result['score']

        # Map to EntailmentLabel
        if 'entail' in label_str:
            label = EntailmentLabel.ENTAILMENT
            valid = True
            evidence = f"Claim entailed by ground truth: {premise}"
        elif 'contradict' in label_str:
            label = EntailmentLabel.CONTRADICTION
            valid = False
            evidence = f"Claim contradicts ground truth: {premise}"
        else:
            label = EntailmentLabel.NEUTRAL
            valid = True  # Neutral = can't disprove
            evidence = f"Claim neither entailed nor contradicted: {premise}"

        return FactCheckResult(
            valid=valid,
            label=label,
            confidence=confidence,
            evidence=evidence
        )

    def _check_entailment_llm(
        self,
        premise: str,
        hypothesis: str
    ) -> FactCheckResult:
        """Fallback: Use LLM for entailment checking."""
        prompt = f"""Determine if the Hypothesis is supported by the Premise.

Premise (Ground Truth): {premise}
Hypothesis (Claim): {hypothesis}

Respond in JSON:
{{
"entailment": "entailment" | "contradiction" | "neutral",
"confidence": 0.0-1.0,
"reasoning": "brief explanation"
}}"""

        response = self.llm.invoke(prompt)

        try:
            result = json.loads(response.content)

            label_map = {
                "entailment": EntailmentLabel.ENTAILMENT,
                "contradiction": EntailmentLabel.CONTRADICTION,
                "neutral": EntailmentLabel.NEUTRAL
            }

            label = label_map.get(result["entailment"], EntailmentLabel.NEUTRAL)
            valid = label != EntailmentLabel.CONTRADICTION

            return FactCheckResult(
                valid=valid,
                label=label,
                confidence=result["confidence"],
                evidence=result["reasoning"]
            )
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
            # FIX: was a bare `except:` which also swallowed KeyboardInterrupt/
            # SystemExit; only malformed LLM output should trigger the fallback.
            return self._check_entailment_fallback(premise, hypothesis)

    def _check_entailment_fallback(
        self,
        premise: str,
        hypothesis: str
    ) -> FactCheckResult:
        """Last resort: Simple keyword matching."""
        # Extract direction words
        increase_words = ["increase", "grew", "rose", "up", "gain", "higher"]
        decrease_words = ["decrease", "fell", "dropped", "down", "loss", "lower"]

        premise_dir = None
        if any(w in premise.lower() for w in increase_words):
            premise_dir = "increase"
        elif any(w in premise.lower() for w in decrease_words):
            premise_dir = "decrease"

        hyp_dir = None
        if any(w in hypothesis.lower() for w in increase_words):
            hyp_dir = "increase"
        elif any(w in hypothesis.lower() for w in decrease_words):
            hyp_dir = "decrease"

        # Check if directions match
        if premise_dir and hyp_dir:
            if premise_dir == hyp_dir:
                return FactCheckResult(
                    valid=True,
                    label=EntailmentLabel.ENTAILMENT,
                    confidence=0.7,
                    evidence=f"Directions match: both {premise_dir}"
                )
            else:
                return FactCheckResult(
                    valid=False,
                    label=EntailmentLabel.CONTRADICTION,
                    confidence=0.8,
                    evidence=f"Direction mismatch: {premise_dir} vs {hyp_dir}"
                )

        return FactCheckResult(
            valid=True,
            label=EntailmentLabel.NEUTRAL,
            confidence=0.5,
            evidence="Cannot determine entailment"
        )

    def _classify_argument(self, argument: str) -> str:
        """Classify argument type for appropriate validation."""
        arg_lower = argument.lower()

        if any(w in arg_lower for w in ["revenue", "earnings", "sales", "income"]):
            return "revenue"
        elif any(w in arg_lower for w in ["price", "stock", "share"]):
            return "price"
        elif any(w in arg_lower for w in ["rsi", "macd", "sma", "ema", "bollinger"]):
            return "technical"
        else:
            return "qualitative"

    def _get_cache_key(self, argument: str, trading_date: str) -> str:
        """Generate cache key from argument and date."""
        # Hash argument + date (md5 is fine here: keying, not security).
        hash_input = f"{argument}_{trading_date}"
        return hashlib.md5(hash_input.encode()).hexdigest()

    def _add_to_cache(self, key: str, result: FactCheckResult):
        """Add result to cache with size limit."""
        if len(self.cache) >= self.cache_size:
            # Remove oldest entry (simple FIFO; dicts preserve insertion order)
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]

        self.cache[key] = result

    def get_cache_stats(self) -> Dict[str, int]:
        """Get cache statistics."""
        return {
            "size": len(self.cache),
            "max_size": self.cache_size,
            "hit_rate": self._calculate_hit_rate()
        }

    def _calculate_hit_rate(self) -> float:
        """Calculate cache hit rate."""
        # This would need to track hits/misses in production
        return 0.0

    def clear_cache(self):
        """Clear cache (e.g., at end of trading day)."""
        self.cache.clear()
|
|
|
|
|
|
# Example usage
if __name__ == "__main__":
    # Force the keyword/LLM fallback path so the demo needs no local model.
    checker = SemanticFactChecker(use_local_model=False)

    # Test: the first claim contradicts the ground truth below.
    test_claims = [
        "Revenue fell by 5% last quarter",
        "Strong earnings growth of 10%",
    ]

    truth = {
        "revenue_growth_yoy": 0.05,  # Actually grew 5%
        "earnings_growth": 0.10,
    }

    outcomes = checker.validate_arguments(test_claims, truth, "2024-01-15")

    for claim, outcome in outcomes.items():
        print(f"\nArgument: {claim}")
        print(f"Valid: {outcome.valid}")
        print(f"Label: {outcome.label.value}")
        print(f"Evidence: {outcome.evidence}")
        print(f"Cached: {outcome.cached}")