223 lines
8.7 KiB
Python
223 lines
8.7 KiB
Python
"""
Unit Tests for Semantic Fact Checker

Tests:
- NLI-based semantic contradiction detection
- Targeted validation (final arguments only)
- Hash-based caching
- "Revenue fell" vs "Revenue rose" detection
"""
|
|
|
|
import os
import sys
import unittest

# Put the parent of this file's directory at the front of the import path so
# the `tradingagents` package imported below can be resolved from there.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from tradingagents.validation.semantic_fact_checker import (
    EntailmentLabel,
    FactCheckResult,
    SemanticFactChecker,
)
|
|
|
|
|
|
class TestSemanticFactChecker(unittest.TestCase):
    """Test suite for semantic fact checking.

    Each test gets a newly constructed ``SemanticFactChecker`` from
    ``setUp`` (built with ``use_local_model=False``), so every test starts
    from its own checker instance.
    """

    # Trading date used by tests that do not depend on a specific day.
    TRADING_DATE = "2024-01-15"

    def setUp(self):
        """Set up test fixtures."""
        # Use fallback mode (no NLI model) for testing
        self.checker = SemanticFactChecker(use_local_model=False)

    def _validate_one(self, argument, ground_truth, trading_date=TRADING_DATE):
        """Validate a single claim and return its entry from the result mapping.

        Args:
            argument: The claim text to validate.
            ground_truth: Mapping of ground-truth data passed to the checker.
            trading_date: Date string forwarded to ``validate_arguments``.

        Returns:
            The value stored under ``argument`` in the mapping returned by
            ``self.checker.validate_arguments``.
        """
        results = self.checker.validate_arguments(
            [argument], ground_truth, trading_date
        )
        return results[argument]

    def test_validate_contradictory_revenue_claim(self):
        """CRITICAL: Test detection of semantic contradiction."""
        # Ground truth: Revenue GREW 5%
        # Claim: Revenue FELL 5%
        # Expected: CONTRADICTION
        result = self._validate_one(
            "Revenue fell by 5% last quarter",
            {"revenue_growth_yoy": 0.05},  # Grew 5%
        )

        self.assertFalse(result.valid, "Contradictory claim should be invalid")
        self.assertEqual(
            result.label,
            EntailmentLabel.CONTRADICTION,
            "Should detect contradiction",
        )
        self.assertIn(
            "mismatch",
            result.evidence.lower(),
            "Evidence should mention direction mismatch",
        )

    def test_validate_correct_revenue_claim(self):
        """Test validation of correct claim."""
        result = self._validate_one(
            "Revenue increased by approximately 5%",
            {"revenue_growth_yoy": 0.05},
        )

        self.assertTrue(result.valid, "Correct claim should be valid")
        self.assertEqual(
            result.label,
            EntailmentLabel.ENTAILMENT,
            "Should detect entailment",
        )

    def test_validate_price_increase_claim(self):
        """Test price movement validation."""
        result = self._validate_one(
            "Stock price rose significantly",
            {"price_change_pct": 0.10},  # 10% increase
        )

        self.assertTrue(result.valid, "Price increase claim should be valid")

    def test_validate_price_decrease_contradiction(self):
        """Test detection of price direction contradiction."""
        result = self._validate_one(
            "Stock price fell sharply",
            {"price_change_pct": 0.10},  # Actually rose 10%
        )

        self.assertFalse(result.valid, "Contradictory price claim should be invalid")
        self.assertEqual(result.label, EntailmentLabel.CONTRADICTION)

    def test_validate_technical_indicator_claim(self):
        """Test technical indicator validation."""
        result = self._validate_one(
            "RSI is at 45.2",
            {"indicators": {"RSI": 45.2}},
        )

        self.assertTrue(result.valid, "Correct RSI value should be valid")
        self.assertEqual(result.label, EntailmentLabel.ENTAILMENT)

    def test_validate_technical_indicator_mismatch(self):
        """Test detection of incorrect technical indicator value."""
        result = self._validate_one(
            "RSI is at 70",
            {"indicators": {"RSI": 45.2}},
        )

        self.assertFalse(result.valid, "Incorrect RSI value should be invalid")
        self.assertEqual(result.label, EntailmentLabel.CONTRADICTION)

    def test_caching_same_argument(self):
        """Test that identical arguments are cached."""
        arguments = ["Revenue grew 5%"]
        ground_truth = {"revenue_growth_yoy": 0.05}
        trading_date = "2024-01-15"

        # First call - not cached
        results1 = self.checker.validate_arguments(arguments, ground_truth, trading_date)
        self.assertFalse(results1[arguments[0]].cached, "First call should not be cached")

        # Second call - should be cached
        results2 = self.checker.validate_arguments(arguments, ground_truth, trading_date)
        self.assertTrue(results2[arguments[0]].cached, "Second call should be cached")

    def test_caching_different_dates(self):
        """Test that cache is scoped by trading date."""
        arguments = ["Revenue grew 5%"]
        ground_truth = {"revenue_growth_yoy": 0.05}

        # Same argument, different dates
        results1 = self.checker.validate_arguments(arguments, ground_truth, "2024-01-15")
        results2 = self.checker.validate_arguments(arguments, ground_truth, "2024-01-16")

        # Both should not be cached (different dates)
        self.assertFalse(
            results1[arguments[0]].cached,
            "First date should not be served from cache",
        )
        self.assertFalse(
            results2[arguments[0]].cached,
            "A different trading date should not reuse the cached result",
        )

    def test_targeted_validation_multiple_arguments(self):
        """Test validation of multiple arguments (targeted, not full conversation)."""
        arguments = [
            "Revenue grew 5%",
            "Earnings increased 10%",
            "Price rose 3%",
        ]
        ground_truth = {
            "revenue_growth_yoy": 0.05,
            "earnings_growth": 0.10,
            "price_change_pct": 0.03,
        }

        results = self.checker.validate_arguments(
            arguments, ground_truth, self.TRADING_DATE
        )

        # All should be valid; subTest reports every failing argument instead
        # of stopping at the first one.
        for arg in arguments:
            with self.subTest(argument=arg):
                self.assertTrue(results[arg].valid, f"Argument '{arg}' should be valid")

    def test_qualitative_claim_neutral(self):
        """Test that qualitative claims return neutral."""
        result = self._validate_one("The company has strong leadership", {})

        self.assertTrue(result.valid, "Qualitative claims should be valid (can't verify)")
        self.assertEqual(result.label, EntailmentLabel.NEUTRAL)

    def test_missing_ground_truth_data(self):
        """Test handling of missing ground truth data."""
        result = self._validate_one("Revenue grew 5%", {})  # No revenue data

        self.assertTrue(result.valid, "Should be valid when ground truth missing")
        self.assertEqual(result.label, EntailmentLabel.NEUTRAL)

    def test_cache_size_limit(self):
        """Test that cache respects size limit."""
        checker = SemanticFactChecker(use_local_model=False, cache_size=5)
        ground_truth = {"revenue_growth_yoy": 0.05}

        # Add 10 arguments (exceeds cache size of 5)
        for i in range(10):
            arguments = [f"Revenue grew {i}%"]
            checker.validate_arguments(arguments, ground_truth, "2024-01-15")

        stats = checker.get_cache_stats()
        self.assertLessEqual(stats["size"], 5, "Cache should not exceed max size")

    def test_clear_cache(self):
        """Test cache clearing."""
        arguments = ["Revenue grew 5%"]
        ground_truth = {"revenue_growth_yoy": 0.05}

        self.checker.validate_arguments(arguments, ground_truth, "2024-01-15")
        self.assertGreater(len(self.checker.cache), 0, "Cache should have entries")

        self.checker.clear_cache()
        self.assertEqual(len(self.checker.cache), 0, "Cache should be empty after clear")

    def test_classify_argument_types(self):
        """Test argument classification."""
        test_cases = [
            ("Revenue grew 5%", "revenue"),
            ("Stock price rose", "price"),
            ("RSI is oversold", "technical"),
            ("Company has good management", "qualitative"),
        ]

        # One subTest per case so a misclassification does not mask the
        # remaining cases.
        for argument, expected_type in test_cases:
            with self.subTest(argument=argument):
                result = self.checker._classify_argument(argument)
                self.assertEqual(
                    result,
                    expected_type,
                    f"'{argument}' should be classified as '{expected_type}'",
                )
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: run the suite with per-test output.
    unittest.main(verbosity=2)
|