# TradingAgents/tradingagents/dataflows/rag_isolator.py
#
# 273 lines
# 9.6 KiB
# Python

"""
RAG Isolator - Strict Context Enforcement
Forces LLMs to answer ONLY from provided context, preventing use of pre-trained knowledge.
"""
import re
from typing import Dict, List, Any, Optional

from langchain.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage
class RAGIsolator:
    """
    Enforce strict RAG (Retrieval-Augmented Generation) to prevent knowledge contamination.

    LLMs must answer ONLY from provided context, not from training data.
    The class builds prompts that embed a formatted context block and offers a
    heuristic, keyword-based validator for model responses.
    """

    # Real company names whose appearance in a response counts as a violation.
    COMPANY_INDICATORS = (
        "Apple", "Microsoft", "Google", "Amazon", "Meta", "Tesla",
        "Nvidia", "AMD", "Intel", "Oracle", "Salesforce",
    )
    # Real product names whose appearance in a response counts as a violation.
    PRODUCT_INDICATORS = (
        "iPhone", "Windows", "Android", "Azure", "AWS",
        "GeForce", "RTX", "H100", "A100",
    )
    # Phrases suggesting the model is drawing on pre-trained knowledge.
    KNOWLEDGE_PHRASES = (
        "i know", "as i know", "from my knowledge",
        "based on my training", "historically", "typically",
    )
    # Maximum number of news items rendered into the context block.
    MAX_NEWS_ARTICLES = 5
    # Confidence deducted per recorded violation.
    CONFIDENCE_PENALTY = 0.2

    _CURRENCY_SYMBOL_RE = re.compile(r'[\$€£¥₹]')
    # "$" followed by comma-grouped digits ("$1,200") or by 3+ plain digits
    # ("$480"); small amounts such as "$1.20" are deliberately not matched here.
    _ABSOLUTE_PRICE_RE = re.compile(r'\$\d{1,3}(?:,\d{3})+|\$\d{3,}')

    def __init__(self, strict_mode: bool = True):
        """
        Initialize RAG isolator.

        Args:
            strict_mode: If True, explicitly forbid use of pre-trained knowledge
        """
        self.strict_mode = strict_mode

    def create_isolated_prompt(
        self,
        query: str,
        context: Dict[str, Any],
        system_role: str = "financial analyst"
    ) -> "ChatPromptTemplate":
        """
        Create a prompt that enforces strict RAG isolation.

        The returned template keeps a ``{query}`` placeholder in the human
        message; ``query`` itself is not baked into the template, so callers
        still supply it when formatting/invoking the prompt.

        Args:
            query: The question to answer (not embedded; see above)
            context: Structured context data (market data, news, fundamentals)
            system_role: Role description for the agent

        Returns:
            ChatPromptTemplate with strict RAG enforcement
        """
        system_message = self._build_system_message(context, system_role)
        return ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("human", "{query}"),
        ])

    @staticmethod
    def _escape_braces(text: str) -> str:
        """Double curly braces so *text* is treated literally by an f-string style template."""
        return text.replace("{", "{{").replace("}", "}}")

    def _build_system_message(self, context: Dict[str, Any], system_role: str) -> str:
        """
        Build the system-message template text for the current mode.

        Context text and the role are brace-escaped because the result is used
        as a template: an unescaped "{" or "}" coming from the data would be
        parsed as a template variable instead of literal text.

        Args:
            context: Structured context data passed to ``_format_context``
            system_role: Role description for the agent

        Returns:
            Template string for the system message
        """
        context_str = self._escape_braces(self._format_context(context))
        role = self._escape_braces(system_role)

        if self.strict_mode:
            return (
                f"You are a {role}. You must answer questions using ONLY the information provided in the CONTEXT section below.\n"
                "CRITICAL RULES:\n"
                "1. DO NOT use any knowledge from your training data\n"
                "2. DO NOT make assumptions about companies, products, or events\n"
                '3. If the CONTEXT does not contain the information needed to answer, respond with "INSUFFICIENT DATA"\n'
                "4. DO NOT identify companies by price levels, volatility patterns, or other indirect signals\n"
                "5. Treat all data as anonymous - you are analyzing ASSET_XXX, not real companies\n"
                "CONTEXT:\n"
                f"{context_str}\n"
                'If you cannot answer from the CONTEXT alone, you MUST respond: "INSUFFICIENT DATA: [explain what information is missing]"\n'
            )

        return (
            f"You are a {role}. Use the following context to answer questions.\n"
            "CONTEXT:\n"
            f"{context_str}\n"
        )

    def _format_context(self, context: Dict[str, Any]) -> str:
        """
        Format structured context into readable text.

        Recognized keys are "market_data", "news", "fundamentals" and
        "historical"; other keys are ignored. A key that is absent or mapped
        to None produces no section.

        Args:
            context: Dictionary with market data, news, fundamentals, etc.

        Returns:
            Formatted context string (empty when no recognized section exists)
        """
        sections: List[str] = []

        # Market Data Section
        market_data = context.get("market_data")
        if market_data is not None:
            sections.append("=== MARKET DATA ===")
            sections.append(f"Current Price Index: {market_data.get('close', 'N/A')}")
            sections.append(f"Volume: {market_data.get('volume', 'N/A')}")
            indicators = market_data.get("indicators")
            if indicators is not None:
                sections.append("\nTechnical Indicators:")
                sections.extend(
                    f" {indicator}: {value}" for indicator, value in indicators.items()
                )

        # News Section
        news = context.get("news")
        if news is not None:
            sections.append("\n=== NEWS SUMMARY ===")
            for i, article in enumerate(news[:self.MAX_NEWS_ARTICLES], 1):
                # Fall back summary -> title -> N/A, skipping empty/None values.
                text = article.get("summary") or article.get("title") or "N/A"
                sections.append(f"{i}. {text}")

        # Fundamentals Section
        fundamentals = context.get("fundamentals")
        if fundamentals is not None:
            sections.append("\n=== FUNDAMENTAL DATA ===")
            sections.append(f"Revenue Growth: {fundamentals.get('revenue_growth', 'N/A')}")
            sections.append(f"Earnings: {fundamentals.get('earnings', 'N/A')}")
            sections.append(f"Debt/Equity: {fundamentals.get('debt_to_equity', 'N/A')}")

        # Historical Performance
        hist = context.get("historical")
        if hist is not None:
            sections.append("\n=== HISTORICAL PERFORMANCE ===")
            sections.append(f"1-Month Return: {hist.get('1m_return', 'N/A')}")
            sections.append(f"3-Month Return: {hist.get('3m_return', 'N/A')}")
            sections.append(f"6-Month Return: {hist.get('6m_return', 'N/A')}")

        return "\n".join(sections)

    @staticmethod
    def _find_terms(text: str, terms, allow_suffix: bool = False) -> List[str]:
        """
        Return the entries of *terms* that occur in *text* as whole words.

        Matching is case-insensitive and anchored on word boundaries so that
        e.g. "Meta" does not fire on "metadata" or "AWS" on "flaws". With
        ``allow_suffix`` a trailing plural/possessive ("s", "'s") still matches.
        """
        suffix = r"(?:'s|s)?" if allow_suffix else ""
        return [
            term for term in terms
            if re.search(r"\b" + re.escape(term) + suffix + r"\b", text, re.IGNORECASE)
        ]

    def validate_response(self, response: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate that LLM response only uses information from context.

        This is a keyword/regex heuristic: it looks for known company and
        product names, currency symbols, absolute dollar amounts and phrases
        hinting at pre-trained knowledge. ``context`` is accepted for interface
        compatibility but is not consulted by the current checks.

        Args:
            response: LLM's response
            context: The context that was provided (currently unused)

        Returns:
            {
                "valid": bool,            # True when no violation was found
                "violations": List[str],  # one message per detected issue
                "confidence": float       # 1.0 minus 0.2 per violation, floored at 0.0
            }
        """
        violations: List[str] = []

        # Check for company name leakage
        for company in self._find_terms(response, self.COMPANY_INDICATORS, allow_suffix=True):
            violations.append(f"Mentioned real company name: {company}")

        # Check for product name leakage
        for product in self._find_terms(response, self.PRODUCT_INDICATORS, allow_suffix=True):
            violations.append(f"Mentioned real product name: {product}")

        # CRITICAL: Check for currency symbols (immediate hallucination)
        # If context uses normalized values, ANY currency symbol is a leak
        symbols = sorted(set(self._CURRENCY_SYMBOL_RE.findall(response)))
        if symbols:
            # Rendered like a set literal, but in sorted order so the message
            # is identical from run to run.
            rendered = "{" + ", ".join(repr(symbol) for symbol in symbols) + "}"
            violations.append(
                f"HALLUCINATION: Used currency symbols {rendered} (context uses normalized index)"
            )

        # Check for absolute dollar amounts; recorded in addition to the
        # currency-symbol violation above, so such a response is penalized twice.
        absolute_prices = self._ABSOLUTE_PRICE_RE.findall(response)
        if absolute_prices:
            violations.append(f"Mentioned absolute dollar prices: {absolute_prices}")

        # Check for "I know" or "based on my knowledge" phrases
        for phrase in self._find_terms(response, self.KNOWLEDGE_PHRASES):
            violations.append(f"Used pre-trained knowledge phrase: '{phrase}'")

        # Reduce confidence per violation, never below zero.
        confidence = 1.0 - (len(violations) * self.CONFIDENCE_PENALTY)
        return {
            "valid": not violations,
            "violations": violations,
            "confidence": max(0.0, confidence)
        }

    def create_fact_grounded_prompt(
        self,
        query: str,
        facts: List[str],
        allow_inference: bool = False
    ) -> str:
        """
        Create a prompt that grounds LLM in specific facts.

        Args:
            query: Question to answer
            facts: List of factual statements (rendered as a 1-based numbered list)
            allow_inference: Whether to allow logical inference from facts

        Returns:
            Formatted prompt string
        """
        facts_str = "\n".join(f"{i}. {fact}" for i, fact in enumerate(facts, 1))
        if allow_inference:
            instruction = "You may make logical inferences from these facts, but clearly state when you are inferring."
        else:
            instruction = "Answer using ONLY these facts. Do not infer or extrapolate."
        return (
            "FACTS:\n"
            f"{facts_str}\n"
            f"QUESTION: {query}\n"
            f"INSTRUCTION: {instruction}\n"
            "ANSWER:"
        )
# Example usage
def _demo() -> None:
    """Print a sample isolated prompt, then validate one clean and one leaking response."""
    isolator = RAGIsolator(strict_mode=True)

    # Create isolated context
    context = {
        "market_data": {
            "close": 102.5,
            "volume": 50000000,
            "indicators": {
                "RSI": 45.2,
                "MACD": 0.8,
                "50_SMA": 100.3
            }
        },
        "news": [
            {"summary": "Company ASSET_042 reported quarterly earnings"},
            {"summary": "Product A sales exceeded expectations"}
        ],
        "fundamentals": {
            "revenue_growth": 0.05,
            "earnings": 1.2,
            "debt_to_equity": 0.3
        }
    }

    # Create prompt
    query = "Should I buy this asset?"
    prompt = isolator.create_isolated_prompt(query, context)
    print("=== ISOLATED PROMPT ===")
    print(prompt.format(query=query))

    # Test response validation
    print("\n=== RESPONSE VALIDATION ===")

    # Good response (only uses context)
    good_response = "Based on the RSI of 45.2 and positive revenue growth of 5%, the asset shows moderate strength."
    result = isolator.validate_response(good_response, context)
    print(f"Good response valid: {result['valid']}")

    # Bad response (uses pre-trained knowledge)
    bad_response = "This is clearly Apple based on the price level. iPhone sales are strong."
    result = isolator.validate_response(bad_response, context)
    print(f"Bad response valid: {result['valid']}")
    print(f"Violations: {result['violations']}")


if __name__ == "__main__":
    _demo()