"""
|
|
JSON Retry Loop - Enforce Schema Compliance
|
|
|
|
If LLM outputs text instead of JSON, retry with error message.
|
|
Max 2 retries before hard failure.
|
|
"""
|
|
|
|
from typing import Type, TypeVar, Optional, Callable
|
|
from pydantic import BaseModel, ValidationError
|
|
import json
|
|
import time
|
|
|
|
T = TypeVar('T', bound=BaseModel)
|
|
|
|
|
|
class JSONRetryLoop:
    """
    Enforce JSON schema compliance with a retry mechanism.

    If the LLM outputs invalid JSON or violates the schema, retry with
    error feedback.
    """

    def __init__(self, max_retries: int = 2):
        """
        Initialize the retry loop.

        Args:
            max_retries: Maximum retry attempts (default 2)
        """
        self.max_retries = max_retries
        self.retry_stats = {
            "total_calls": 0,
            "successful_first_try": 0,
            "successful_after_retry": 0,
            "total_failures": 0
        }

    def invoke_with_retry(
        self,
        llm_callable: Callable,
        schema: Type[T],
        prompt: str,
        context: dict
    ) -> tuple[Optional[T], dict]:
        """
        Invoke the LLM with automatic retry on schema violation.

        Args:
            llm_callable: Function that calls the LLM (e.g., llm.invoke)
            schema: Pydantic schema class
            prompt: Initial prompt template (str.format placeholders allowed)
            context: Context dict for prompt formatting

        Returns:
            (parsed_output, metadata) where metadata contains retry info;
            parsed_output is None if all attempts failed.
        """
        self.retry_stats["total_calls"] += 1

        metadata = {
            "attempts": 0,
            "errors": [],
            "latency": 0.0
        }

        start_time = time.time()

        for attempt in range(self.max_retries + 1):
            metadata["attempts"] = attempt + 1

            try:
                if attempt == 0:
                    # First attempt: use the original prompt
                    response = llm_callable(prompt.format(**context))
                else:
                    # Retry: prepend feedback from the most recent error
                    retry_prompt = self._build_retry_prompt(
                        prompt, context, metadata["errors"][-1]
                    )
                    response = llm_callable(retry_prompt)

                # Extract JSON from the response, parse it, and validate it
                # against the schema; any failure falls through to the
                # except clauses below and triggers a retry.
                json_str = self._extract_json(response.content)
                json_data = json.loads(json_str)
                parsed_output = schema(**json_data)

                # Success
                metadata["latency"] = time.time() - start_time

                if attempt == 0:
                    self.retry_stats["successful_first_try"] += 1
                else:
                    self.retry_stats["successful_after_retry"] += 1

                return parsed_output, metadata

            except json.JSONDecodeError as e:
                metadata["errors"].append(f"Invalid JSON: {e}")

            except ValidationError as e:
                metadata["errors"].append(f"Schema validation failed: {e}")

            except Exception as e:
                metadata["errors"].append(f"Unexpected error: {e}")

        # All retries exhausted
        self.retry_stats["total_failures"] += 1
        metadata["latency"] = time.time() - start_time

        return None, metadata

    def _extract_json(self, text: str) -> str:
        """
        Extract JSON from an LLM response.

        Handles cases where the LLM wraps JSON in markdown code blocks.
        """
        # Strip markdown code fences if present
        if "```json" in text:
            start = text.find("```json") + len("```json")
            end = text.find("```", start)
            if end == -1:  # unterminated fence: take the rest of the text
                end = len(text)
            return text[start:end].strip()
        elif "```" in text:
            start = text.find("```") + len("```")
            end = text.find("```", start)
            if end == -1:
                end = len(text)
            return text[start:end].strip()

        # Otherwise take the outermost brace-delimited span
        if "{" in text and "}" in text:
            start = text.find("{")
            end = text.rfind("}") + 1
            return text[start:end]

        return text.strip()

    def _build_retry_prompt(
        self,
        original_prompt: str,
        context: dict,
        error_msg: str
    ) -> str:
        """
        Build a retry prompt with error feedback.

        Args:
            original_prompt: Original prompt template
            context: Context dict
            error_msg: Error message from the previous attempt

        Returns:
            Retry prompt with error feedback prepended
        """
        # Format the original prompt first: calling .format() on the full
        # instruction would choke on any braces inside error_msg (schema
        # validation errors may contain them).
        formatted_prompt = original_prompt.format(**context)

        return f"""
CRITICAL ERROR: Your previous response failed validation.

ERROR: {error_msg}

You MUST output valid JSON matching the required schema. Do NOT output:
- Markdown explanations
- Text before or after the JSON
- Invalid JSON syntax
- Missing required fields

Try again. Output ONLY valid JSON.

---

{formatted_prompt}
"""

    def get_stats(self) -> dict:
        """Get retry statistics."""
        total = self.retry_stats["total_calls"]
        if total == 0:
            return self.retry_stats

        return {
            **self.retry_stats,
            "first_try_success_rate": self.retry_stats["successful_first_try"] / total,
            "overall_success_rate": (
                self.retry_stats["successful_first_try"] +
                self.retry_stats["successful_after_retry"]
            ) / total,
            "failure_rate": self.retry_stats["total_failures"] / total
        }


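# A minimal integration sketch -- an assumption, not part of this module:
# it assumes the langchain-openai package, whose ChatOpenAI.invoke() returns
# a message object exposing .content, which is what invoke_with_retry reads;
# the model name and {ticker} prompt are illustrative only:
#
#     from langchain_openai import ChatOpenAI
#
#     llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
#     loop = JSONRetryLoop(max_retries=2)
#     result, meta = loop.invoke_with_retry(
#         llm.invoke, AnalystOutput,
#         "Analyze {ticker} and output JSON", {"ticker": "AAPL"}
#     )
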
# Example usage with a mock LLM
if __name__ == "__main__":
    try:
        from tradingagents.schemas.agent_schemas import AnalystOutput
    except ImportError:
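        # Stand-in schema so the demo runs without the tradingagents
        # package -- an assumption: field names and types are inferred from
        # the mock JSON response below, not from the real package.
        class AnalystOutput(BaseModel):
            analyst_type: str
            key_findings: list[str]
            signal: str
            confidence: float
            reasoning: str
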
    # Mock LLM callable that replays canned responses in order
    class MockLLM:
        def __init__(self, responses):
            self.responses = responses
            self.call_count = 0

        def invoke(self, prompt):
            response = self.responses[self.call_count]
            self.call_count += 1

            # Mimic a chat-model response object exposing .content
            class Response:
                def __init__(self, content):
                    self.content = content

            return Response(response)

    # Test: first attempt fails (invalid JSON), second succeeds
    responses = [
        "This is just text, not JSON",  # First attempt fails
        '''```json
{
    "analyst_type": "market",
    "key_findings": ["Finding 1", "Finding 2", "Finding 3"],
    "signal": "BUY",
    "confidence": 0.8,
    "reasoning": "Strong technical indicators suggest bullish momentum with volume confirmation."
}
```'''  # Second attempt succeeds
    ]

    mock_llm = MockLLM(responses)
    retry_loop = JSONRetryLoop(max_retries=2)

    prompt = "Analyze the market and output JSON"
    context = {}

    result, metadata = retry_loop.invoke_with_retry(
        mock_llm.invoke,
        AnalystOutput,
        prompt,
        context
    )

    print(f"Attempts: {metadata['attempts']}")
    print(f"Errors: {metadata['errors']}")
    print(f"Success: {result is not None}")

    if result:
        print("\nParsed output:")
        # Pydantic v2 renamed .json() to .model_dump_json(); support both
        if hasattr(result, "model_dump_json"):
            print(result.model_dump_json(indent=2))
        else:
            print(result.json(indent=2))

    print("\nRetry stats:")
    print(retry_loop.get_stats())
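
    # Expected console output -- a sketch derived from the mock above, not
    # captured from a run; the parsed-output dump depends on the actual
    # AnalystOutput schema and is omitted:
    #
    #     Attempts: 2
    #     Errors: ['Invalid JSON: Expecting value: line 1 column 1 (char 0)']
    #     Success: True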