TradingAgents/tests/test_output_parser.py

145 lines
5.2 KiB
Python

"""Tests for structured output parsing with retry on malformed output."""
import unittest
from pydantic import ValidationError
from tradingagents.agents.output_parser import StructuredOutputParser, validate_agent_output
from tradingagents.agents.schemas import (
AnalystReport,
PortfolioDecision,
RiskAssessment,
TraderDecision,
extract_fields,
)
# Well-formed JSON fixtures, one per schema exercised by the tests below.
# Each literal is split one key per line; adjacent string literals are
# concatenated by the compiler, so the resulting value is a single-line
# JSON document.

# Payload validated against AnalystReport.
VALID_ANALYST_JSON = (
    '{"summary": "Stock is up",'
    ' "detailed_analysis": "Strong earnings beat",'
    ' "key_points": ["Revenue up 20%"],'
    ' "confidence": 0.85}'
)

# Payload validated against TraderDecision.
VALID_TRADER_JSON = (
    '{"action": "Buy",'
    ' "reasoning": "Bullish trend",'
    ' "confidence": 0.9,'
    ' "price_target": 150.0,'
    ' "stop_loss": 130.0}'
)

# Payload validated against RiskAssessment.
VALID_RISK_JSON = (
    '{"stance": "Cautious",'
    ' "argument": "Volatility is high",'
    ' "risk_factors": ["Market downturn"],'
    ' "confidence": 0.6}'
)

# Payload validated against PortfolioDecision.
VALID_PORTFOLIO_JSON = (
    '{"rating": "Buy",'
    ' "executive_summary": "Strong buy",'
    ' "investment_thesis": "Solid fundamentals",'
    ' "confidence": 0.8,'
    ' "price_target": 200.0,'
    ' "time_horizon": "6 months"}'
)
class TestStructuredOutputParser(unittest.TestCase):
    """Exercise StructuredOutputParser.parse and parse_with_retry.

    Covers the happy path, rejection of bad payloads, and the retry loop
    that asks a caller-supplied LLM callable for a corrected response.
    """

    # ------------------------------------------------------------------
    # parse()
    # ------------------------------------------------------------------

    def test_parse_valid_json(self):
        """A well-formed payload yields an AnalystReport with its confidence."""
        report = StructuredOutputParser(AnalystReport).parse(VALID_ANALYST_JSON)

        self.assertIsInstance(report, AnalystReport)
        self.assertEqual(report.confidence, 0.85)

    def test_parse_json_in_code_fence(self):
        """JSON wrapped in a ```json fence is expected to parse as well."""
        fenced = "```json\n" + VALID_TRADER_JSON + "\n```"
        trader_parser = StructuredOutputParser(TraderDecision)

        decision = trader_parser.parse(fenced)

        # The action is compared through its ``.value`` attribute.
        self.assertEqual(decision.action.value, "Buy")

    def test_parse_malformed_raises(self):
        """Plain prose with no JSON in it must make parse() raise."""
        analyst_parser = StructuredOutputParser(AnalystReport)

        with self.assertRaises((ValidationError, Exception)):
            analyst_parser.parse("This is not JSON at all")

    def test_parse_missing_required_field_raises(self):
        """A payload lacking required fields must make parse() raise."""
        trader_parser = StructuredOutputParser(TraderDecision)
        # Only 'confidence' is supplied; 'action' and 'reasoning' are absent.
        incomplete = '{"confidence": 0.5}'

        with self.assertRaises((ValidationError, Exception)):
            trader_parser.parse(incomplete)

    def test_parse_confidence_out_of_range_raises(self):
        """A confidence of 1.5 is expected to be rejected."""
        analyst_parser = StructuredOutputParser(AnalystReport)
        out_of_range = '{"summary": "x", "detailed_analysis": "x", "key_points": [], "confidence": 1.5}'

        with self.assertRaises((ValidationError, Exception)):
            analyst_parser.parse(out_of_range)

    # ------------------------------------------------------------------
    # parse_with_retry()
    # ------------------------------------------------------------------

    def test_retry_recovers_from_malformed_output(self):
        """Bad initial text, then one LLM call returning valid JSON, succeeds."""
        analyst_parser = StructuredOutputParser(AnalystReport)
        prompts_seen = []

        def fake_llm(prompt: str) -> str:
            # Record every prompt so the number of retries can be checked.
            prompts_seen.append(prompt)
            return VALID_ANALYST_JSON

        report = analyst_parser.parse_with_retry("not json", fake_llm, max_retries=2)

        self.assertIsInstance(report, AnalystReport)
        # Exactly one retry should have been needed.
        self.assertEqual(len(prompts_seen), 1)

    def test_retry_exhausted_raises(self):
        """If every retry still returns garbage, parse_with_retry raises."""
        trader_parser = StructuredOutputParser(TraderDecision)

        with self.assertRaises(Exception):
            trader_parser.parse_with_retry(
                "bad",
                lambda prompt: "still not valid",
                max_retries=2,
            )

    def test_retry_second_attempt_succeeds(self):
        """First retry is still broken; the second one returns valid JSON."""
        risk_parser = StructuredOutputParser(RiskAssessment)
        call_count = 0

        def llm(prompt: str) -> str:
            nonlocal call_count
            call_count += 1
            # Only the first call answers with garbage.
            return VALID_RISK_JSON if call_count >= 2 else "still broken"

        assessment = risk_parser.parse_with_retry("garbage", llm, max_retries=2)

        self.assertIsInstance(assessment, RiskAssessment)
        self.assertEqual(call_count, 2)
class TestValidateAgentOutput(unittest.TestCase):
    """Check the (model, fields) pair returned by validate_agent_output."""

    def test_valid_output_returns_model_and_fields(self):
        """Valid JSON yields a PortfolioDecision plus its extracted fields."""
        outcome = validate_agent_output(VALID_PORTFOLIO_JSON, PortfolioDecision)
        decision, extracted = outcome

        self.assertIsInstance(decision, PortfolioDecision)
        self.assertEqual(extracted["rating"], "Buy")
        self.assertIn("confidence", extracted)

    def test_invalid_output_without_llm_returns_none(self):
        """With llm=None, unparseable text gives (None, {}) instead of raising."""
        outcome = validate_agent_output("garbage", AnalystReport, llm=None)
        report, extracted = outcome

        self.assertIsNone(report)
        self.assertEqual(extracted, {})
class TestExtractFields(unittest.TestCase):
    """Check which keys extract_fields returns for validated models."""

    def test_analyst_extract(self):
        """Structured keys are present; the free-text summary is left out."""
        report = AnalystReport.model_validate_json(VALID_ANALYST_JSON)

        extracted = extract_fields(report)

        self.assertIn("confidence", extracted)
        self.assertIn("key_points", extracted)
        # The 'summary' text field is expected to be excluded.
        self.assertNotIn("summary", extracted)

    def test_trader_extract_enum_to_str(self):
        """The extracted action compares equal to the plain string 'Buy'."""
        decision = TraderDecision.model_validate_json(VALID_TRADER_JSON)

        extracted = extract_fields(decision)

        self.assertEqual(extracted["action"], "Buy")

    def test_none_values_omitted(self):
        """Fields not supplied to the model do not show up in the result."""
        decision = TraderDecision(action="Buy", reasoning="bullish trend", confidence=0.5)

        extracted = extract_fields(decision)

        # Neither key was passed to the constructor above.
        for absent_key in ("price_target", "stop_loss"):
            self.assertNotIn(absent_key, extracted)