145 lines
5.2 KiB
Python
145 lines
5.2 KiB
Python
"""Tests for structured output parsing with retry on malformed output."""
|
|
|
|
import unittest
|
|
|
|
from pydantic import ValidationError
|
|
|
|
from tradingagents.agents.output_parser import StructuredOutputParser, validate_agent_output
|
|
from tradingagents.agents.schemas import (
|
|
AnalystReport,
|
|
PortfolioDecision,
|
|
RiskAssessment,
|
|
TraderDecision,
|
|
extract_fields,
|
|
)
|
|
|
|
|
|
VALID_ANALYST_JSON = (
|
|
'{"summary": "Stock is up", "detailed_analysis": "Strong earnings beat",'
|
|
' "key_points": ["Revenue up 20%"], "confidence": 0.85}'
|
|
)
|
|
|
|
VALID_TRADER_JSON = (
|
|
'{"action": "Buy", "reasoning": "Bullish trend", "confidence": 0.9,'
|
|
' "price_target": 150.0, "stop_loss": 130.0}'
|
|
)
|
|
|
|
VALID_RISK_JSON = (
|
|
'{"stance": "Cautious", "argument": "Volatility is high",'
|
|
' "risk_factors": ["Market downturn"], "confidence": 0.6}'
|
|
)
|
|
|
|
VALID_PORTFOLIO_JSON = (
|
|
'{"rating": "Buy", "executive_summary": "Strong buy",'
|
|
' "investment_thesis": "Solid fundamentals", "confidence": 0.8,'
|
|
' "price_target": 200.0, "time_horizon": "6 months"}'
|
|
)
|
|
|
|
|
|
class TestStructuredOutputParser(unittest.TestCase):
|
|
"""Test parse, retry, and extract_fields for all schema types."""
|
|
|
|
def test_parse_valid_json(self):
|
|
parser = StructuredOutputParser(AnalystReport)
|
|
result = parser.parse(VALID_ANALYST_JSON)
|
|
self.assertIsInstance(result, AnalystReport)
|
|
self.assertEqual(result.confidence, 0.85)
|
|
|
|
def test_parse_json_in_code_fence(self):
|
|
text = f"```json\n{VALID_TRADER_JSON}\n```"
|
|
parser = StructuredOutputParser(TraderDecision)
|
|
result = parser.parse(text)
|
|
self.assertEqual(result.action.value, "Buy")
|
|
|
|
def test_parse_malformed_raises(self):
|
|
parser = StructuredOutputParser(AnalystReport)
|
|
with self.assertRaises((ValidationError, Exception)):
|
|
parser.parse("This is not JSON at all")
|
|
|
|
def test_parse_missing_required_field_raises(self):
|
|
parser = StructuredOutputParser(TraderDecision)
|
|
# Missing 'action' and 'reasoning'
|
|
with self.assertRaises((ValidationError, Exception)):
|
|
parser.parse('{"confidence": 0.5}')
|
|
|
|
def test_parse_confidence_out_of_range_raises(self):
|
|
parser = StructuredOutputParser(AnalystReport)
|
|
bad = '{"summary": "x", "detailed_analysis": "x", "key_points": [], "confidence": 1.5}'
|
|
with self.assertRaises((ValidationError, Exception)):
|
|
parser.parse(bad)
|
|
|
|
def test_retry_recovers_from_malformed_output(self):
|
|
"""Malformed first response → retry → valid JSON → success."""
|
|
parser = StructuredOutputParser(AnalystReport)
|
|
calls: list[str] = []
|
|
|
|
def fake_llm(prompt: str) -> str:
|
|
calls.append(prompt)
|
|
return VALID_ANALYST_JSON
|
|
|
|
result = parser.parse_with_retry("not json", fake_llm, max_retries=2)
|
|
self.assertIsInstance(result, AnalystReport)
|
|
self.assertEqual(len(calls), 1) # one retry was needed
|
|
|
|
def test_retry_exhausted_raises(self):
|
|
"""All retries return garbage → raises."""
|
|
parser = StructuredOutputParser(TraderDecision)
|
|
|
|
def always_bad(prompt: str) -> str:
|
|
return "still not valid"
|
|
|
|
with self.assertRaises(Exception):
|
|
parser.parse_with_retry("bad", always_bad, max_retries=2)
|
|
|
|
def test_retry_second_attempt_succeeds(self):
|
|
"""First retry still bad, second retry returns valid JSON."""
|
|
parser = StructuredOutputParser(RiskAssessment)
|
|
attempt = {"n": 0}
|
|
|
|
def llm(prompt: str) -> str:
|
|
attempt["n"] += 1
|
|
if attempt["n"] < 2:
|
|
return "still broken"
|
|
return VALID_RISK_JSON
|
|
|
|
result = parser.parse_with_retry("garbage", llm, max_retries=2)
|
|
self.assertIsInstance(result, RiskAssessment)
|
|
self.assertEqual(attempt["n"], 2)
|
|
|
|
|
|
class TestValidateAgentOutput(unittest.TestCase):
|
|
"""Test the convenience wrapper used by agent nodes."""
|
|
|
|
def test_valid_output_returns_model_and_fields(self):
|
|
model, fields = validate_agent_output(VALID_PORTFOLIO_JSON, PortfolioDecision)
|
|
self.assertIsInstance(model, PortfolioDecision)
|
|
self.assertEqual(fields["rating"], "Buy")
|
|
self.assertIn("confidence", fields)
|
|
|
|
def test_invalid_output_without_llm_returns_none(self):
|
|
model, fields = validate_agent_output("garbage", AnalystReport, llm=None)
|
|
self.assertIsNone(model)
|
|
self.assertEqual(fields, {})
|
|
|
|
|
|
class TestExtractFields(unittest.TestCase):
|
|
"""Test structured field extraction from validated models."""
|
|
|
|
def test_analyst_extract(self):
|
|
m = AnalystReport.model_validate_json(VALID_ANALYST_JSON)
|
|
f = extract_fields(m)
|
|
self.assertIn("confidence", f)
|
|
self.assertIn("key_points", f)
|
|
self.assertNotIn("summary", f) # text fields excluded
|
|
|
|
def test_trader_extract_enum_to_str(self):
|
|
m = TraderDecision.model_validate_json(VALID_TRADER_JSON)
|
|
f = extract_fields(m)
|
|
self.assertEqual(f["action"], "Buy") # enum → string
|
|
|
|
def test_none_values_omitted(self):
|
|
m = TraderDecision(action="Buy", reasoning="bullish trend", confidence=0.5)
|
|
f = extract_fields(m)
|
|
self.assertNotIn("price_target", f)
|
|
self.assertNotIn("stop_loss", f)
|