This commit is contained in:
claytonbrown 2026-04-20 17:44:09 +05:00 committed by GitHub
commit 64480eea27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 462 additions and 0 deletions

144
tests/test_output_parser.py Normal file
View File

@ -0,0 +1,144 @@
"""Tests for structured output parsing with retry on malformed output."""
import unittest
from pydantic import ValidationError
from tradingagents.agents.output_parser import StructuredOutputParser, validate_agent_output
from tradingagents.agents.schemas import (
AnalystReport,
PortfolioDecision,
RiskAssessment,
TraderDecision,
extract_fields,
)
# Canonical, well-formed JSON payloads — one per schema type — reused by the
# parsing, retry, and extraction tests below. Each satisfies every required
# field and constraint of its schema (e.g. confidence within [0, 1]).
VALID_ANALYST_JSON = (
    '{"summary": "Stock is up", "detailed_analysis": "Strong earnings beat",'
    ' "key_points": ["Revenue up 20%"], "confidence": 0.85}'
)

VALID_TRADER_JSON = (
    '{"action": "Buy", "reasoning": "Bullish trend", "confidence": 0.9,'
    ' "price_target": 150.0, "stop_loss": 130.0}'
)

VALID_RISK_JSON = (
    '{"stance": "Cautious", "argument": "Volatility is high",'
    ' "risk_factors": ["Market downturn"], "confidence": 0.6}'
)

VALID_PORTFOLIO_JSON = (
    '{"rating": "Buy", "executive_summary": "Strong buy",'
    ' "investment_thesis": "Solid fundamentals", "confidence": 0.8,'
    ' "price_target": 200.0, "time_horizon": "6 months"}'
)
class TestStructuredOutputParser(unittest.TestCase):
    """Test parse and retry behavior of StructuredOutputParser for each schema type."""

    def test_parse_valid_json(self):
        parser = StructuredOutputParser(AnalystReport)
        result = parser.parse(VALID_ANALYST_JSON)
        self.assertIsInstance(result, AnalystReport)
        self.assertEqual(result.confidence, 0.85)

    def test_parse_json_in_code_fence(self):
        text = f"```json\n{VALID_TRADER_JSON}\n```"
        parser = StructuredOutputParser(TraderDecision)
        result = parser.parse(text)
        self.assertEqual(result.action.value, "Buy")

    def test_parse_malformed_raises(self):
        # NOTE: previously asserted (ValidationError, Exception); the tuple is
        # redundant because ValidationError subclasses Exception.
        parser = StructuredOutputParser(AnalystReport)
        with self.assertRaises(Exception):
            parser.parse("This is not JSON at all")

    def test_parse_missing_required_field_raises(self):
        parser = StructuredOutputParser(TraderDecision)
        # Missing 'action' and 'reasoning'
        with self.assertRaises(Exception):
            parser.parse('{"confidence": 0.5}')

    def test_parse_confidence_out_of_range_raises(self):
        # confidence is constrained to [0, 1] by the schema's ge/le validators.
        parser = StructuredOutputParser(AnalystReport)
        bad = '{"summary": "x", "detailed_analysis": "x", "key_points": [], "confidence": 1.5}'
        with self.assertRaises(Exception):
            parser.parse(bad)

    def test_retry_recovers_from_malformed_output(self):
        """Malformed first response → retry → valid JSON → success."""
        parser = StructuredOutputParser(AnalystReport)
        calls: list[str] = []

        def fake_llm(prompt: str) -> str:
            calls.append(prompt)
            return VALID_ANALYST_JSON

        result = parser.parse_with_retry("not json", fake_llm, max_retries=2)
        self.assertIsInstance(result, AnalystReport)
        self.assertEqual(len(calls), 1)  # one retry was needed

    def test_retry_exhausted_raises(self):
        """All retries return garbage → raises."""
        parser = StructuredOutputParser(TraderDecision)

        def always_bad(prompt: str) -> str:
            return "still not valid"

        with self.assertRaises(Exception):
            parser.parse_with_retry("bad", always_bad, max_retries=2)

    def test_retry_second_attempt_succeeds(self):
        """First retry still bad, second retry returns valid JSON."""
        parser = StructuredOutputParser(RiskAssessment)
        attempt = {"n": 0}

        def llm(prompt: str) -> str:
            attempt["n"] += 1
            if attempt["n"] < 2:
                return "still broken"
            return VALID_RISK_JSON

        result = parser.parse_with_retry("garbage", llm, max_retries=2)
        self.assertIsInstance(result, RiskAssessment)
        self.assertEqual(attempt["n"], 2)
class TestValidateAgentOutput(unittest.TestCase):
    """Test the convenience wrapper used by agent nodes."""

    def test_valid_output_returns_model_and_fields(self):
        parsed, extracted = validate_agent_output(VALID_PORTFOLIO_JSON, PortfolioDecision)
        self.assertIsInstance(parsed, PortfolioDecision)
        # The extracted dict carries the structured fields of the model.
        self.assertEqual(extracted["rating"], "Buy")
        self.assertIn("confidence", extracted)

    def test_invalid_output_without_llm_returns_none(self):
        # Without an LLM there is no retry path: garbage degrades to (None, {}).
        parsed, extracted = validate_agent_output("garbage", AnalystReport, llm=None)
        self.assertIsNone(parsed)
        self.assertEqual(extracted, {})
class TestExtractFields(unittest.TestCase):
    """Test structured field extraction from validated models."""

    def test_analyst_extract(self):
        report = AnalystReport.model_validate_json(VALID_ANALYST_JSON)
        fields = extract_fields(report)
        # Only structured data survives extraction; prose fields are dropped.
        self.assertNotIn("summary", fields)
        self.assertIn("confidence", fields)
        self.assertIn("key_points", fields)

    def test_trader_extract_enum_to_str(self):
        decision = TraderDecision.model_validate_json(VALID_TRADER_JSON)
        fields = extract_fields(decision)
        self.assertEqual(fields["action"], "Buy")  # enum converted to string

    def test_none_values_omitted(self):
        decision = TraderDecision(action="Buy", reasoning="bullish trend", confidence=0.5)
        fields = extract_fields(decision)
        for optional_key in ("price_target", "stop_loss"):
            self.assertNotIn(optional_key, fields)

View File

@ -17,6 +17,9 @@ from .risk_mgmt.neutral_debator import create_neutral_debator
from .managers.research_manager import create_research_manager
from .managers.portfolio_manager import create_portfolio_manager
from .output_parser import StructuredOutputParser, validate_agent_output
from .schemas import extract_fields
from .trader.trader import create_trader
__all__ = [
@ -37,4 +40,7 @@ __all__ = [
"create_conservative_debator",
"create_social_media_analyst",
"create_trader",
"StructuredOutputParser",
"validate_agent_output",
"extract_fields",
]

View File

@ -8,6 +8,8 @@ from tradingagents.agents.utils.agent_utils import (
get_insider_transactions,
get_language_instruction,
)
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import AnalystReport
from tradingagents.dataflows.config import get_config
@ -60,6 +62,9 @@ def create_fundamentals_analyst(llm):
if len(result.tool_calls) == 0:
report = result.content
model, _ = validate_agent_output(report, AnalystReport, llm)
if model:
report = model.summary or report
return {
"messages": [result],

View File

@ -5,6 +5,8 @@ from tradingagents.agents.utils.agent_utils import (
get_language_instruction,
get_stock_data,
)
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import AnalystReport
from tradingagents.dataflows.config import get_config
@ -79,6 +81,9 @@ Volume-Based Indicators:
if len(result.tool_calls) == 0:
report = result.content
model, _ = validate_agent_output(report, AnalystReport, llm)
if model:
report = model.summary or report
return {
"messages": [result],

View File

@ -5,6 +5,8 @@ from tradingagents.agents.utils.agent_utils import (
get_language_instruction,
get_news,
)
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import AnalystReport
from tradingagents.dataflows.config import get_config
@ -53,6 +55,9 @@ def create_news_analyst(llm):
if len(result.tool_calls) == 0:
report = result.content
model, _ = validate_agent_output(report, AnalystReport, llm)
if model:
report = model.summary or report
return {
"messages": [result],

View File

@ -1,5 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import AnalystReport
from tradingagents.dataflows.config import get_config
@ -48,6 +50,9 @@ def create_social_media_analyst(llm):
if len(result.tool_calls) == 0:
report = result.content
model, _ = validate_agent_output(report, AnalystReport, llm)
if model:
report = model.summary or report
return {
"messages": [result],

View File

@ -1,4 +1,6 @@
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import PortfolioDecision
def create_portfolio_manager(llm, memory):
@ -56,6 +58,8 @@ Be decisive and ground every conclusion in specific evidence from the analysts.{
response = llm.invoke(prompt)
model, _ = validate_agent_output(response.content, PortfolioDecision, llm)
new_risk_debate_state = {
"judge_decision": response.content,
"history": risk_debate_state["history"],

View File

@ -1,5 +1,7 @@
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import TraderDecision
def create_research_manager(llm, memory):
@ -41,6 +43,8 @@ Debate History:
{history}"""
response = llm.invoke(prompt)
model, _ = validate_agent_output(response.content, TraderDecision, llm)
new_investment_debate_state = {
"judge_decision": response.content,
"history": investment_debate_state.get("history", ""),

View File

@ -0,0 +1,170 @@
"""Output parser that validates LLM responses against Pydantic schemas."""
import json
import logging
import re
from collections.abc import Callable
from typing import Any, TypeVar
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, ValidationError
T = TypeVar("T", bound=BaseModel)
logger = logging.getLogger(__name__)
# Prompt template sent back to the LLM when its previous response failed
# schema validation; it carries the error, the format instructions, and a
# truncated copy of the invalid response.
_RETRY_PROMPT = (
    "Your previous response could not be parsed. Error:\n{error}\n\n"
    "Please respond ONLY with valid JSON matching this schema:\n{instructions}\n\n"
    "Previous (invalid) response:\n{previous}"
)

# Default number of reformat-and-retry round trips in parse_with_retry.
MAX_RETRIES = 2


class StructuredOutputParser:
    """Validates LLM text output against a Pydantic model.

    Usage:
        parser = StructuredOutputParser(AnalystReport)
        instructions = parser.get_format_instructions()  # inject into prompt
        result = parser.parse(llm_response_text)  # returns AnalystReport or raises

    With retry:
        result = parser.parse_with_retry(llm_response_text, llm_caller)
    """

    def __init__(self, schema: type[T]) -> None:
        self.schema = schema
        self._langchain_parser = PydanticOutputParser(pydantic_object=schema)

    def get_format_instructions(self) -> str:
        """Return formatting instructions to embed in the LLM prompt."""
        return self._langchain_parser.get_format_instructions()

    def parse(self, text: str) -> T:
        """Parse LLM text into the Pydantic model.

        Tries direct JSON extraction (code fence or raw JSON object) first,
        then falls back to the langchain parser.

        Raises:
            ValidationError: If the output doesn't match the schema.
        """
        json_str = self._extract_json(text)
        if json_str is not None:
            try:
                data = json.loads(json_str)
                return self.schema.model_validate(data)
            except (json.JSONDecodeError, ValidationError):
                # Extracted span wasn't usable JSON; let langchain try the
                # full text below.
                pass
        try:
            return self._langchain_parser.parse(text)
        except Exception as e:
            # Re-raise as ValidationError for consistent handling. For the
            # "value_error" error type, pydantic v2 renders the message from
            # ctx["error"], so the full description goes there; "msg" is not a
            # valid InitErrorDetails key and must not be supplied.
            raise ValidationError.from_exception_data(
                title=self.schema.__name__,
                line_errors=[
                    {
                        "type": "value_error",
                        "loc": (),
                        "input": text[:500],
                        "ctx": {"error": f"Failed to parse LLM output: {e}"},
                    }
                ],
            ) from e

    def parse_with_retry(
        self,
        text: str,
        llm_caller: Callable[[str], str],
        max_retries: int = MAX_RETRIES,
    ) -> T:
        """Parse with automatic retry on validation failure.

        On failure, sends the error and format instructions back to the LLM
        via *llm_caller* (a callable that accepts a prompt string and returns
        the LLM's text response).

        Args:
            text: Initial LLM response text to parse.
            llm_caller: ``fn(prompt) -> response_text`` used for retries.
            max_retries: Maximum number of retry attempts (default 2).

        Returns:
            Validated Pydantic model instance.

        Raises:
            ValidationError: If all retries are exhausted.
            ValueError: If ``max_retries`` is negative.
        """
        last_error: Exception | None = None
        current_text = text
        for attempt in range(1 + max_retries):
            try:
                return self.parse(current_text)
            except Exception as exc:  # ValidationError is already an Exception
                last_error = exc
                if attempt < max_retries:
                    logger.warning(
                        "Validation failed for %s (attempt %d/%d): %s",
                        self.schema.__name__,
                        attempt + 1,
                        1 + max_retries,
                        exc,
                    )
                    retry_prompt = _RETRY_PROMPT.format(
                        error=str(exc),
                        instructions=self.get_format_instructions(),
                        previous=current_text[:1000],  # cap prompt growth
                    )
                    current_text = llm_caller(retry_prompt)
        # All retries exhausted — surface the last parse failure. last_error
        # can only still be None when the loop body never ran.
        if last_error is None:
            raise ValueError("max_retries must be >= 0")
        raise last_error

    @staticmethod
    def _extract_json(text: str) -> str | None:
        """Extract JSON from markdown code fences or find a raw JSON object."""
        # Match ```json ... ``` or ``` ... ```
        match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
        if match:
            return match.group(1).strip()
        # Try to find a raw JSON object (non-greedy nesting pattern avoids
        # spanning multiple blocks; handles one level of nested braces).
        match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text)
        if match:
            return match.group(0)
        return None
def validate_agent_output(
    text: str,
    schema: type[T],
    llm: Any | None = None,
) -> tuple[T | None, dict]:
    """Validate agent output against a schema, with optional LLM retry.

    Args:
        text: Raw LLM response text to validate.
        schema: Pydantic model class the text must conform to.
        llm: Optional chat model; when provided, malformed output triggers a
            reformat-and-retry round trip via ``parse_with_retry``.

    Returns:
        ``(model_instance, extracted_fields)`` on success, or ``(None, {})``
        on failure — graceful degradation so agent nodes can fall back to the
        raw text instead of crashing.
    """
    # Function-scope import — presumably to avoid a circular import between
    # this module and schemas; confirm before hoisting to module level.
    from tradingagents.agents.schemas import extract_fields

    parser = StructuredOutputParser(schema)
    try:
        if llm is not None:
            # Build the retry closure only when a model is actually available,
            # so the None case never captures a dead reference.
            def _llm_caller(prompt: str) -> str:
                return llm.invoke(prompt).content

            model = parser.parse_with_retry(text, _llm_caller)
        else:
            model = parser.parse(text)
        return model, extract_fields(model)
    except Exception as exc:
        # Deliberately broad: any parse/validation failure degrades to the
        # caller's raw text. Include the error so failures are diagnosable.
        logger.warning(
            "Schema validation failed for %s (%s), passing raw text through",
            schema.__name__,
            exc,
        )
        return None, {}

View File

@ -1,3 +1,5 @@
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import RiskAssessment
def create_bear_researcher(llm, memory):
@ -43,6 +45,8 @@ Use this information to deliver a compelling bear argument, refute the bull's cl
response = llm.invoke(prompt)
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
argument = f"Bear Analyst: {response.content}"
new_investment_debate_state = {

View File

@ -1,3 +1,5 @@
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import RiskAssessment
def create_bull_researcher(llm, memory):
@ -41,6 +43,8 @@ Use this information to deliver a compelling bull argument, refute the bear's co
response = llm.invoke(prompt)
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
argument = f"Bull Analyst: {response.content}"
new_investment_debate_state = {

View File

@ -1,3 +1,5 @@
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import RiskAssessment
def create_aggressive_debator(llm):
@ -32,6 +34,8 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
response = llm.invoke(prompt)
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
argument = f"Aggressive Analyst: {response.content}"
new_risk_debate_state = {

View File

@ -1,3 +1,5 @@
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import RiskAssessment
def create_conservative_debator(llm):
@ -32,6 +34,8 @@ Engage by questioning their optimism and emphasizing the potential downsides the
response = llm.invoke(prompt)
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
argument = f"Conservative Analyst: {response.content}"
new_risk_debate_state = {

View File

@ -1,3 +1,5 @@
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import RiskAssessment
def create_neutral_debator(llm):
@ -32,6 +34,8 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
response = llm.invoke(prompt)
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
argument = f"Neutral Analyst: {response.content}"
new_risk_debate_state = {

View File

@ -0,0 +1,90 @@
"""Pydantic models for structured agent outputs."""
from enum import Enum
from typing import Optional, Union
from pydantic import BaseModel, Field
class ActionSignal(str, Enum):
    """Closed set of trading actions a trader decision may carry.

    Subclasses ``str`` so values serialize/compare as plain strings.
    """

    BUY = "Buy"
    SELL = "Sell"
    HOLD = "Hold"
class PortfolioRating(str, Enum):
    """Five-point portfolio rating scale, from most to least bullish.

    Subclasses ``str`` so values serialize/compare as plain strings.
    """

    BUY = "Buy"
    OVERWEIGHT = "Overweight"
    HOLD = "Hold"
    UNDERWEIGHT = "Underweight"
    SELL = "Sell"
class AnalystReport(BaseModel):
    """Structured output from any analyst (market, news, fundamentals, social media)."""

    summary: str = Field(description="Concise summary of key findings")
    detailed_analysis: str = Field(description="Full analysis with supporting evidence")
    key_points: list[str] = Field(description="Bullet list of actionable insights")
    # ge/le validators reject any confidence outside [0, 1].
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1")
class TraderDecision(BaseModel):
    """Structured output from the trader agent."""

    action: ActionSignal = Field(description="Trading action: Buy, Sell, or Hold")
    reasoning: str = Field(description="Rationale for the decision")
    # ge/le validators reject any confidence outside [0, 1].
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1")
    # Optional price levels; None when not applicable (e.g. a Hold).
    price_target: Optional[float] = Field(default=None, description="Target price if applicable")
    stop_loss: Optional[float] = Field(default=None, description="Stop-loss price if applicable")
class RiskAssessment(BaseModel):
    """Structured output from a risk debater (aggressive, conservative, neutral)."""

    # Free-form stance string (not an enum), e.g. "Cautious".
    stance: str = Field(description="The analyst's stance on the trade")
    argument: str = Field(description="Core argument with supporting evidence")
    risk_factors: list[str] = Field(description="Key risk factors identified")
    # ge/le validators reject any confidence outside [0, 1].
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1")
class PortfolioDecision(BaseModel):
    """Structured output from the portfolio manager."""

    rating: PortfolioRating = Field(description="Buy / Overweight / Hold / Underweight / Sell")
    executive_summary: str = Field(description="Concise action plan")
    investment_thesis: str = Field(description="Detailed reasoning for the decision")
    # ge/le validators reject any confidence outside [0, 1].
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1")
    # Optional forward-looking fields; None when the manager omits them.
    price_target: Optional[float] = Field(default=None, description="Target price if applicable")
    time_horizon: Optional[str] = Field(default=None, description="Recommended holding period")
# Key fields to extract from each model type (only non-text, structured data).
# Order here determines insertion order of the resulting dict.
_EXTRACT_KEYS: dict[type[BaseModel], list[str]] = {
    AnalystReport: ["confidence", "key_points"],
    TraderDecision: ["action", "confidence", "price_target", "stop_loss"],
    RiskAssessment: ["stance", "confidence", "risk_factors"],
    PortfolioDecision: ["rating", "confidence", "price_target", "time_horizon"],
}


def extract_fields(
    model: Union[AnalystReport, TraderDecision, RiskAssessment, PortfolioDecision],
) -> dict:
    """Extract key structured fields from a validated Pydantic model.

    Returns a flat dict containing only the actionable structured fields
    (rating, action, price targets, confidence, etc.) with None values omitted.
    Enum values are converted to their string representation.
    """
    # Unknown model types fall back to every declared field.
    selected = _EXTRACT_KEYS.get(type(model), list(model.model_fields.keys()))
    return {
        # Enums collapse to their underlying string value.
        key: (value.value if isinstance(value, Enum) else value)
        for key in selected
        if (value := getattr(model, key)) is not None
    }

View File

@ -1,6 +1,8 @@
import functools
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.agents.output_parser import validate_agent_output
from tradingagents.agents.schemas import TraderDecision
def create_trader(llm, memory):
@ -38,6 +40,8 @@ def create_trader(llm, memory):
result = llm.invoke(messages)
model, _ = validate_agent_output(result.content, TraderDecision, llm)
return {
"messages": [result],
"trader_investment_plan": result.content,