# TradingAgents/tests/test_rag_isolator.py

"""
Unit Tests for RAG Isolator
Tests:
- Prompt creation with strict RAG enforcement
- Context formatting
- Response validation (knowledge contamination detection)
- Fact grounding
"""
import unittest
import sys
import os
# Add parent directory to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from tradingagents.dataflows.rag_isolator import RAGIsolator
class TestRAGIsolator(unittest.TestCase):
    """Unit tests for RAGIsolator prompt building, context formatting and validation.

    Each test runs against a fresh strict-mode isolator and a context fixture
    (built in ``setUp``) whose news items refer to the asset as ``ASSET_042``.
    """

    def setUp(self):
        """Create a strict-mode isolator plus a context covering every section."""
        self.isolator = RAGIsolator(strict_mode=True)

        technicals = {
            "RSI": 45.2,
            "MACD": 0.8,
            "50_SMA": 100.3,
        }
        market = {
            "close": 102.5,
            "volume": 50_000_000,
            "indicators": technicals,
        }
        headlines = [
            {"summary": "Company ASSET_042 reported quarterly earnings"},
            {"summary": "Product A sales exceeded expectations"},
        ]
        fundamentals = {
            "revenue_growth": 0.05,
            "earnings": 1.2,
            "debt_to_equity": 0.3,
        }
        history = {
            "1m_return": 0.03,
            "3m_return": 0.08,
            "6m_return": 0.15,
        }
        self.context = {
            "market_data": market,
            "news": headlines,
            "fundamentals": fundamentals,
            "historical": history,
        }

    # -- helpers (not collected: names do not start with "test") -----------

    def _render(self, context):
        """Return the isolator's formatted text for *context*."""
        return self.isolator._format_context(context)

    def _check(self, response):
        """Validate *response* against the shared fixture context."""
        return self.isolator.validate_response(response, self.context)

    def _assert_all_in(self, haystack, *needles):
        """Assert every needle occurs in *haystack*; stops at the first miss."""
        for needle in needles:
            self.assertIn(needle, haystack)

    # -- isolated prompt ---------------------------------------------------

    def test_create_isolated_prompt_strict_mode(self):
        """Test prompt creation in strict mode."""
        question = "Should I buy this asset?"
        template = self.isolator.create_isolated_prompt(question, self.context)
        rendered = template.format(query=question)
        # Strict mode must carry all three enforcement phrases.
        self._assert_all_in(
            rendered,
            "ONLY the information provided",
            "DO NOT use any knowledge from your training data",
            "INSUFFICIENT DATA",
        )

    def test_create_isolated_prompt_non_strict_mode(self):
        """Test prompt creation in non-strict mode."""
        relaxed = RAGIsolator(strict_mode=False)
        question = "What is the trend?"
        template = relaxed.create_isolated_prompt(question, self.context)
        rendered = template.format(query=question)
        # The hard training-data prohibition is expected only in strict mode.
        self.assertNotIn(
            "DO NOT use any knowledge from your training data", rendered
        )

    # -- context formatting ------------------------------------------------

    def test_format_context_market_data(self):
        """Test context formatting includes market data."""
        self._assert_all_in(
            self._render(self.context), "MARKET DATA", "102.5", "RSI", "45.2"
        )

    def test_format_context_news(self):
        """Test context formatting includes news."""
        self._assert_all_in(
            self._render(self.context), "NEWS SUMMARY", "ASSET_042", "Product A"
        )

    def test_format_context_fundamentals(self):
        """Test context formatting includes fundamentals."""
        self._assert_all_in(
            self._render(self.context),
            "FUNDAMENTAL DATA",
            "Revenue Growth",
            "0.05",
        )

    def test_format_context_historical(self):
        """Test context formatting includes historical performance."""
        self._assert_all_in(
            self._render(self.context),
            "HISTORICAL PERFORMANCE",
            "1-Month Return",
            "0.03",
        )

    # -- response validation -----------------------------------------------

    def test_validate_response_clean(self):
        """Test validation of clean response (no violations)."""
        outcome = self._check(
            "Based on the RSI of 45.2 and positive revenue growth of 5%, "
            "the asset shows moderate strength."
        )
        self.assertTrue(outcome["valid"], "Clean response should be valid")
        self.assertEqual(len(outcome["violations"]), 0, "Should have no violations")
        self.assertEqual(outcome["confidence"], 1.0, "Confidence should be 1.0")

    def test_validate_response_company_name_leak(self):
        """Test detection of company name leakage."""
        outcome = self._check("This is clearly Apple based on the fundamentals.")
        self.assertFalse(outcome["valid"], "Should be invalid")
        self.assertGreater(len(outcome["violations"]), 0, "Should have violations")
        self.assertIn(
            "Apple", str(outcome["violations"]), "Should detect Apple mention"
        )

    def test_validate_response_product_name_leak(self):
        """Test detection of product name leakage."""
        outcome = self._check("iPhone sales are driving growth.")
        self.assertFalse(outcome["valid"], "Should be invalid")
        self.assertIn(
            "iPhone", str(outcome["violations"]), "Should detect iPhone mention"
        )

    def test_validate_response_absolute_price_leak(self):
        """Test detection of absolute dollar prices."""
        outcome = self._check("The stock is trading at $480 which is expensive.")
        self.assertFalse(outcome["valid"], "Should be invalid")
        self.assertIn(
            "$480", str(outcome["violations"]), "Should detect absolute price"
        )

    def test_validate_response_knowledge_phrase_leak(self):
        """Test detection of pre-trained knowledge phrases."""
        outcome = self._check(
            "Based on my knowledge, this company typically performs well."
        )
        self.assertFalse(outcome["valid"], "Should be invalid")
        # Violations are treated as strings; any one mentioning "knowledge" counts.
        flagged = any(
            "knowledge" in violation.lower() for violation in outcome["violations"]
        )
        self.assertTrue(flagged, "Should detect knowledge phrase")

    def test_validate_response_multiple_violations(self):
        """Test confidence reduction with multiple violations."""
        outcome = self._check(
            "Apple's iPhone sales at $500 are strong based on my knowledge."
        )
        self.assertFalse(outcome["valid"], "Should be invalid")
        self.assertGreaterEqual(
            len(outcome["violations"]), 3, "Should have multiple violations"
        )
        self.assertLess(outcome["confidence"], 1.0, "Confidence should be reduced")

    # -- fact-grounded prompt ----------------------------------------------

    def test_create_fact_grounded_prompt_no_inference(self):
        """Test fact-grounded prompt without inference."""
        facts = [
            "Revenue grew 5% YoY",
            "Earnings per share: $1.20",
            "Debt-to-equity ratio: 0.3",
        ]
        prompt = self.isolator.create_fact_grounded_prompt(
            "What is the revenue growth?", facts, allow_inference=False
        )
        self._assert_all_in(prompt, "Revenue grew 5% YoY", "Do not infer")

    def test_create_fact_grounded_prompt_with_inference(self):
        """Test fact-grounded prompt with inference allowed."""
        facts = [
            "Revenue grew 5% YoY",
            "Costs decreased 3%",
        ]
        prompt = self.isolator.create_fact_grounded_prompt(
            "What happened to profit margins?", facts, allow_inference=True
        )
        self._assert_all_in(
            prompt,
            "may make logical inferences",
            "clearly state when you are inferring",
        )

    # -- edge cases --------------------------------------------------------

    def test_validate_response_case_insensitive(self):
        """Test that validation is case-insensitive."""
        outcome = self._check("This is APPLE stock.")
        self.assertFalse(
            outcome["valid"], "Should detect case-insensitive company names"
        )

    def test_empty_context(self):
        """Test handling of empty context."""
        rendered = self._render({})
        # Formatting an empty context must not raise and must still yield text.
        self.assertIsInstance(rendered, str)

    def test_partial_context(self):
        """Test handling of partial context (missing sections)."""
        rendered = self._render({"market_data": {"close": 100.0}})
        # Present sections are rendered; absent ones are left out.
        self.assertIn("MARKET DATA", rendered)
        self.assertNotIn("NEWS SUMMARY", rendered)
if __name__ == "__main__":
    # Direct execution: run the whole suite, printing each test's description.
    unittest.main(verbosity=2)