TradingAgents/tradingagents/utils/json_retry.py

"""
JSON Retry Loop - Enforce Schema Compliance
If LLM outputs text instead of JSON, retry with error message.
Max 2 retries before hard failure.
"""
import json
import time
from typing import Callable, Optional, Type, TypeVar

from pydantic import BaseModel, ValidationError

T = TypeVar('T', bound=BaseModel)
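
# A minimal usage sketch (illustrative; assumes a LangChain-style `llm` whose
# .invoke(...) returns a message object with a .content string, and a Pydantic
# schema such as the AnalystOutput used in the example at the bottom of this file):
#
#   retry_loop = JSONRetryLoop(max_retries=2)
#   output, meta = retry_loop.invoke_with_retry(
#       llm.invoke, AnalystOutput, "Analyze {ticker}; output JSON only.", {"ticker": "AAPL"}
#   )
#   if output is None:
#       ...  # all attempts failed; meta["errors"] holds the JSON/schema error messages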

class JSONRetryLoop:
    """
    Enforce JSON schema compliance with retry mechanism.
    If LLM outputs invalid JSON or violates schema, retry with error feedback.
    """

    def __init__(self, max_retries: int = 2):
        """
        Initialize retry loop.

        Args:
            max_retries: Maximum retry attempts (default 2)
        """
        self.max_retries = max_retries
        self.retry_stats = {
            "total_calls": 0,
            "successful_first_try": 0,
            "successful_after_retry": 0,
            "total_failures": 0
        }

    def invoke_with_retry(
        self,
        llm_callable: Callable,
        schema: Type[T],
        prompt: str,
        context: dict
    ) -> tuple[Optional[T], dict]:
        """
        Invoke LLM with automatic retry on schema violation.

        Args:
            llm_callable: Function that calls the LLM (e.g., llm.invoke)
            schema: Pydantic schema class
            prompt: Initial prompt
            context: Context dict for prompt formatting

        Returns:
            (parsed_output, metadata) where metadata contains retry info
        """
        self.retry_stats["total_calls"] += 1
        metadata = {
            "attempts": 0,
            "errors": [],
            "latency": 0.0
        }
        start_time = time.time()

        for attempt in range(self.max_retries + 1):
            metadata["attempts"] = attempt + 1
            try:
                # Invoke LLM
                if attempt == 0:
                    # First attempt: use original prompt
                    response = llm_callable(prompt.format(**context))
                else:
                    # Retry: add error feedback
                    retry_prompt = self._build_retry_prompt(
                        prompt, context, metadata["errors"][-1]
                    )
                    response = llm_callable(retry_prompt)

                # Extract JSON from response
                json_str = self._extract_json(response.content)
                # Parse JSON
                json_data = json.loads(json_str)
                # Validate against schema
                parsed_output = schema(**json_data)

                # Success!
                metadata["latency"] = time.time() - start_time
                if attempt == 0:
                    self.retry_stats["successful_first_try"] += 1
                else:
                    self.retry_stats["successful_after_retry"] += 1
                return parsed_output, metadata

            except json.JSONDecodeError as e:
                error_msg = f"Invalid JSON: {str(e)}"
                metadata["errors"].append(error_msg)
            except ValidationError as e:
                error_msg = f"Schema validation failed: {str(e)}"
                metadata["errors"].append(error_msg)
            except Exception as e:
                error_msg = f"Unexpected error: {str(e)}"
                metadata["errors"].append(error_msg)

        # All retries exhausted
        self.retry_stats["total_failures"] += 1
        metadata["latency"] = time.time() - start_time
        return None, metadata
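
    # For reference, after one failed and one successful attempt the returned
    # metadata might look like (values illustrative; latency is in seconds):
    #   {"attempts": 2,
    #    "errors": ["Invalid JSON: Expecting value: line 1 column 1 (char 0)"],
    #    "latency": 1.42}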

    def _extract_json(self, text: str) -> str:
        """
        Extract JSON from LLM response.
        Handles cases where the LLM wraps JSON in markdown code blocks.
        """
        # Remove markdown code blocks
        if "```json" in text:
            start = text.find("```json") + 7
            end = text.find("```", start)
            if end == -1:  # no closing fence; take the rest of the text
                end = len(text)
            return text[start:end].strip()
        elif "```" in text:
            start = text.find("```") + 3
            end = text.find("```", start)
            if end == -1:
                end = len(text)
            return text[start:end].strip()
        # Try to find a JSON object
        if "{" in text and "}" in text:
            start = text.find("{")
            end = text.rfind("}") + 1
            return text[start:end]
        return text.strip()
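
    # Illustrative behaviour of _extract_json (inputs are hypothetical):
    #   '```json\n{"a": 1}\n```'        -> '{"a": 1}'
    #   'Here you go: {"a": 1} thanks'  -> '{"a": 1}'
    #   'no json at all'                -> 'no json at all'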

    def _build_retry_prompt(
        self,
        original_prompt: str,
        context: dict,
        error_msg: str
    ) -> str:
        """
        Build retry prompt with error feedback.

        Args:
            original_prompt: Original prompt template
            context: Context dict
            error_msg: Error message from previous attempt

        Returns:
            Retry prompt with error feedback
        """
        # Format the original prompt first so that braces inside error_msg
        # (common in JSON/validation error text) cannot break str.format().
        formatted_prompt = original_prompt.format(**context)
        return f"""
CRITICAL ERROR: Your previous response failed validation.
ERROR: {error_msg}
You MUST output valid JSON matching the required schema. Do NOT output:
- Markdown explanations
- Text before or after JSON
- Invalid JSON syntax
- Missing required fields
Try again. Output ONLY valid JSON.
---
{formatted_prompt}
"""
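
    # Illustrative result (hypothetical inputs):
    #   _build_retry_prompt("Analyze {ticker}", {"ticker": "AAPL"}, "Invalid JSON: ...")
    # yields the CRITICAL ERROR banner, the formatting rules, then "---"
    # followed by "Analyze AAPL".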

    def get_stats(self) -> dict:
        """Get retry statistics."""
        total = self.retry_stats["total_calls"]
        if total == 0:
            return self.retry_stats
        return {
            **self.retry_stats,
            "first_try_success_rate": self.retry_stats["successful_first_try"] / total,
            "overall_success_rate": (
                self.retry_stats["successful_first_try"] +
                self.retry_stats["successful_after_retry"]
            ) / total,
            "failure_rate": self.retry_stats["total_failures"] / total
        }
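
    # Example get_stats() output after 10 calls (illustrative numbers):
    #   {"total_calls": 10, "successful_first_try": 8, "successful_after_retry": 1,
    #    "total_failures": 1, "first_try_success_rate": 0.8,
    #    "overall_success_rate": 0.9, "failure_rate": 0.1}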


# Example usage
if __name__ == "__main__":
    from tradingagents.schemas.agent_schemas import AnalystOutput

    # Mock LLM callable
    class MockLLM:
        def __init__(self, responses):
            self.responses = responses
            self.call_count = 0

        def invoke(self, prompt):
            response = self.responses[self.call_count]
            self.call_count += 1

            class Response:
                def __init__(self, content):
                    self.content = content

            return Response(response)

    # Test: First attempt fails (invalid JSON), second succeeds
    responses = [
        "This is just text, not JSON",  # First attempt fails
        '''```json
{
    "analyst_type": "market",
    "key_findings": ["Finding 1", "Finding 2", "Finding 3"],
    "signal": "BUY",
    "confidence": 0.8,
    "reasoning": "Strong technical indicators suggest bullish momentum with volume confirmation."
}
```'''  # Second attempt succeeds
    ]

    mock_llm = MockLLM(responses)
    retry_loop = JSONRetryLoop(max_retries=2)
    prompt = "Analyze the market and output JSON"
    context = {}

    result, metadata = retry_loop.invoke_with_retry(
        mock_llm.invoke,
        AnalystOutput,
        prompt,
        context
    )

    print(f"Attempts: {metadata['attempts']}")
    print(f"Errors: {metadata['errors']}")
    print(f"Success: {result is not None}")
    if result:
        print("\nParsed output:")
        # Pydantic v2 uses model_dump_json(); fall back to .json() on v1
        if hasattr(result, "model_dump_json"):
            print(result.model_dump_json(indent=2))
        else:
            print(result.json(indent=2))

    print("\nRetry stats:")
    print(retry_loop.get_stats())
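
    # A second, illustrative check: every attempt returns non-JSON text, so the
    # loop exhausts max_retries and returns (None, metadata) instead of raising.
    failing_llm = MockLLM([
        "still not JSON",
        "again not JSON",
        "third time, still not JSON"
    ])
    failed_result, failed_meta = retry_loop.invoke_with_retry(
        failing_llm.invoke, AnalystOutput, prompt, context
    )
    print(f"\nHard-failure path: result={failed_result}, attempts={failed_meta['attempts']}")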