"""
Unit Tests for RAG Isolator

Tests:
- Prompt creation with strict RAG enforcement
- Context formatting
- Response validation (knowledge contamination detection)
- Fact grounding
"""

import unittest
import sys
import os

# Add parent directory to path (presumably so the project package imported
# below can be found when this file is run directly -- confirm against the
# repository layout). The path is made absolute so that it keeps working even
# if the working directory changes after import.
sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
)

from tradingagents.dataflows.rag_isolator import RAGIsolator  # noqa: E402


class TestRAGIsolator(unittest.TestCase):
    """Test suite for RAGIsolator.

    Exercises prompt creation (strict and non-strict mode), context
    formatting, response validation and fact-grounded prompt creation.
    """

    def setUp(self):
        """Set up test fixtures.

        Creates a strict-mode isolator and a context dict with the four
        sections used throughout the tests: market data, news, fundamentals
        and historical returns.
        """
        self.isolator = RAGIsolator(strict_mode=True)
        self.context = {
            "market_data": {
                "close": 102.5,
                "volume": 50000000,
                "indicators": {
                    "RSI": 45.2,
                    "MACD": 0.8,
                    "50_SMA": 100.3,
                },
            },
            "news": [
                {"summary": "Company ASSET_042 reported quarterly earnings"},
                {"summary": "Product A sales exceeded expectations"},
            ],
            "fundamentals": {
                "revenue_growth": 0.05,
                "earnings": 1.2,
                "debt_to_equity": 0.3,
            },
            "historical": {
                "1m_return": 0.03,
                "3m_return": 0.08,
                "6m_return": 0.15,
            },
        }

    def test_create_isolated_prompt_strict_mode(self):
        """Test prompt creation in strict mode."""
        query = "Should I buy this asset?"
        prompt = self.isolator.create_isolated_prompt(query, self.context)

        prompt_text = prompt.format(query=query)

        # Check for strict mode instructions
        self.assertIn("ONLY the information provided", prompt_text)
        self.assertIn(
            "DO NOT use any knowledge from your training data", prompt_text
        )
        self.assertIn("INSUFFICIENT DATA", prompt_text)

    def test_create_isolated_prompt_non_strict_mode(self):
        """Test prompt creation in non-strict mode."""
        isolator = RAGIsolator(strict_mode=False)
        query = "What is the trend?"
        prompt = isolator.create_isolated_prompt(query, self.context)

        prompt_text = prompt.format(query=query)

        # Should not have strict warnings
        self.assertNotIn(
            "DO NOT use any knowledge from your training data", prompt_text
        )

    def test_format_context_market_data(self):
        """Test context formatting includes market data."""
        context_str = self.isolator._format_context(self.context)

        self.assertIn("MARKET DATA", context_str)
        self.assertIn("102.5", context_str)
        self.assertIn("RSI", context_str)
        self.assertIn("45.2", context_str)

    def test_format_context_news(self):
        """Test context formatting includes news."""
        context_str = self.isolator._format_context(self.context)

        self.assertIn("NEWS SUMMARY", context_str)
        self.assertIn("ASSET_042", context_str)
        self.assertIn("Product A", context_str)

    def test_format_context_fundamentals(self):
        """Test context formatting includes fundamentals."""
        context_str = self.isolator._format_context(self.context)

        self.assertIn("FUNDAMENTAL DATA", context_str)
        self.assertIn("Revenue Growth", context_str)
        self.assertIn("0.05", context_str)

    def test_format_context_historical(self):
        """Test context formatting includes historical performance."""
        context_str = self.isolator._format_context(self.context)

        self.assertIn("HISTORICAL PERFORMANCE", context_str)
        self.assertIn("1-Month Return", context_str)
        self.assertIn("0.03", context_str)

    def test_validate_response_clean(self):
        """Test validation of clean response (no violations)."""
        response = "Based on the RSI of 45.2 and positive revenue growth of 5%, the asset shows moderate strength."
        result = self.isolator.validate_response(response, self.context)

        self.assertTrue(result["valid"], "Clean response should be valid")
        self.assertEqual(
            len(result["violations"]), 0, "Should have no violations"
        )
        self.assertEqual(result["confidence"], 1.0, "Confidence should be 1.0")

    def test_validate_response_company_name_leak(self):
        """Test detection of company name leakage."""
        response = "This is clearly Apple based on the fundamentals."
        result = self.isolator.validate_response(response, self.context)

        self.assertFalse(result["valid"], "Should be invalid")
        self.assertGreater(
            len(result["violations"]), 0, "Should have violations"
        )
        self.assertIn(
            "Apple", str(result["violations"]), "Should detect Apple mention"
        )

    def test_validate_response_product_name_leak(self):
        """Test detection of product name leakage."""
        response = "iPhone sales are driving growth."
        result = self.isolator.validate_response(response, self.context)

        self.assertFalse(result["valid"], "Should be invalid")
        self.assertIn(
            "iPhone", str(result["violations"]), "Should detect iPhone mention"
        )

    def test_validate_response_absolute_price_leak(self):
        """Test detection of absolute dollar prices."""
        response = "The stock is trading at $480 which is expensive."
        result = self.isolator.validate_response(response, self.context)

        self.assertFalse(result["valid"], "Should be invalid")
        self.assertIn(
            "$480", str(result["violations"]), "Should detect absolute price"
        )

    def test_validate_response_knowledge_phrase_leak(self):
        """Test detection of pre-trained knowledge phrases."""
        response = "Based on my knowledge, this company typically performs well."
        result = self.isolator.validate_response(response, self.context)

        self.assertFalse(result["valid"], "Should be invalid")
        # str() mirrors how the sibling tests inspect violations, so this
        # check does not depend on each violation being a plain string.
        self.assertTrue(
            any("knowledge" in str(v).lower() for v in result["violations"]),
            "Should detect knowledge phrase"
        )

    def test_validate_response_multiple_violations(self):
        """Test confidence reduction with multiple violations."""
        response = "Apple's iPhone sales at $500 are strong based on my knowledge."
        result = self.isolator.validate_response(response, self.context)

        self.assertFalse(result["valid"], "Should be invalid")
        self.assertGreaterEqual(
            len(result["violations"]), 3, "Should have multiple violations"
        )
        self.assertLess(
            result["confidence"], 1.0, "Confidence should be reduced"
        )

    def test_create_fact_grounded_prompt_no_inference(self):
        """Test fact-grounded prompt without inference."""
        facts = [
            "Revenue grew 5% YoY",
            "Earnings per share: $1.20",
            "Debt-to-equity ratio: 0.3",
        ]
        query = "What is the revenue growth?"

        prompt = self.isolator.create_fact_grounded_prompt(
            query, facts, allow_inference=False
        )

        self.assertIn("Revenue grew 5% YoY", prompt)
        self.assertIn("Do not infer", prompt)

    def test_create_fact_grounded_prompt_with_inference(self):
        """Test fact-grounded prompt with inference allowed."""
        facts = [
            "Revenue grew 5% YoY",
            "Costs decreased 3%",
        ]
        query = "What happened to profit margins?"

        prompt = self.isolator.create_fact_grounded_prompt(
            query, facts, allow_inference=True
        )

        self.assertIn("may make logical inferences", prompt)
        self.assertIn("clearly state when you are inferring", prompt)

    def test_validate_response_case_insensitive(self):
        """Test that validation is case-insensitive."""
        response = "This is APPLE stock."
        result = self.isolator.validate_response(response, self.context)

        self.assertFalse(
            result["valid"], "Should detect case-insensitive company names"
        )

    def test_empty_context(self):
        """Test handling of empty context."""
        empty_context = {}
        context_str = self.isolator._format_context(empty_context)

        # Should not crash, just return empty sections
        self.assertIsInstance(context_str, str)

    def test_partial_context(self):
        """Test handling of partial context (missing sections)."""
        partial_context = {
            "market_data": {
                "close": 100.0,
            },
        }

        context_str = self.isolator._format_context(partial_context)

        self.assertIn("MARKET DATA", context_str)
        self.assertNotIn("NEWS SUMMARY", context_str)


if __name__ == '__main__':
    # Run tests with per-test output when this file is executed directly;
    # importing the module does not start the runner.
    unittest.main(verbosity=2)