From cc721a87be9d859a18228089fe5fd1e8de204c65 Mon Sep 17 00:00:00 2001 From: Clayton Brown Date: Mon, 20 Apr 2026 22:43:31 +1000 Subject: [PATCH] feat: Pydantic schema validation for agent outputs (#434) Review feedback applied: - All 13 agent nodes now capture validate_agent_output return value - Analysts use validated model.summary when available - Type hint: llm param changed from object to Any - JSON extraction regex: non-greedy to avoid spanning multiple blocks Closes #434 --- tests/test_output_parser.py | 144 +++++++++++++++ tradingagents/agents/__init__.py | 6 + .../agents/analysts/fundamentals_analyst.py | 5 + .../agents/analysts/market_analyst.py | 5 + tradingagents/agents/analysts/news_analyst.py | 5 + .../agents/analysts/social_media_analyst.py | 5 + .../agents/managers/portfolio_manager.py | 4 + .../agents/managers/research_manager.py | 4 + tradingagents/agents/output_parser.py | 170 ++++++++++++++++++ .../agents/researchers/bear_researcher.py | 4 + .../agents/researchers/bull_researcher.py | 4 + .../agents/risk_mgmt/aggressive_debator.py | 4 + .../agents/risk_mgmt/conservative_debator.py | 4 + .../agents/risk_mgmt/neutral_debator.py | 4 + tradingagents/agents/schemas.py | 90 ++++++++++ tradingagents/agents/trader/trader.py | 4 + 16 files changed, 462 insertions(+) create mode 100644 tests/test_output_parser.py create mode 100644 tradingagents/agents/output_parser.py create mode 100644 tradingagents/agents/schemas.py diff --git a/tests/test_output_parser.py b/tests/test_output_parser.py new file mode 100644 index 00000000..775d34c0 --- /dev/null +++ b/tests/test_output_parser.py @@ -0,0 +1,144 @@ +"""Tests for structured output parsing with retry on malformed output.""" + +import unittest + +from pydantic import ValidationError + +from tradingagents.agents.output_parser import StructuredOutputParser, validate_agent_output +from tradingagents.agents.schemas import ( + AnalystReport, + PortfolioDecision, + RiskAssessment, + TraderDecision, + extract_fields, +) + + +VALID_ANALYST_JSON = ( + '{"summary": "Stock is up", "detailed_analysis": "Strong earnings beat",' + ' "key_points": ["Revenue up 20%"], "confidence": 0.85}' +) + +VALID_TRADER_JSON = ( + '{"action": "Buy", "reasoning": "Bullish trend", "confidence": 0.9,' + ' "price_target": 150.0, "stop_loss": 130.0}' +) + +VALID_RISK_JSON = ( + '{"stance": "Cautious", "argument": "Volatility is high",' + ' "risk_factors": ["Market downturn"], "confidence": 0.6}' +) + +VALID_PORTFOLIO_JSON = ( + '{"rating": "Buy", "executive_summary": "Strong buy",' + ' "investment_thesis": "Solid fundamentals", "confidence": 0.8,' + ' "price_target": 200.0, "time_horizon": "6 months"}' +) + + +class TestStructuredOutputParser(unittest.TestCase): + """Test parse, retry, and extract_fields for all schema types.""" + + def test_parse_valid_json(self): + parser = StructuredOutputParser(AnalystReport) + result = parser.parse(VALID_ANALYST_JSON) + self.assertIsInstance(result, AnalystReport) + self.assertEqual(result.confidence, 0.85) + + def test_parse_json_in_code_fence(self): + text = f"```json\n{VALID_TRADER_JSON}\n```" + parser = StructuredOutputParser(TraderDecision) + result = parser.parse(text) + self.assertEqual(result.action.value, "Buy") + + def test_parse_malformed_raises(self): + parser = StructuredOutputParser(AnalystReport) + with self.assertRaises((ValidationError, Exception)): + parser.parse("This is not JSON at all") + + def test_parse_missing_required_field_raises(self): + parser = StructuredOutputParser(TraderDecision) + # Missing 'action' and 'reasoning' + with self.assertRaises((ValidationError, Exception)): + parser.parse('{"confidence": 0.5}') + + def test_parse_confidence_out_of_range_raises(self): + parser = StructuredOutputParser(AnalystReport) + bad = '{"summary": "x", "detailed_analysis": "x", "key_points": [], "confidence": 1.5}' + with self.assertRaises((ValidationError, Exception)): + parser.parse(bad) + + def test_retry_recovers_from_malformed_output(self): + """Malformed first response → retry → valid JSON → success.""" + parser = StructuredOutputParser(AnalystReport) + calls: list[str] = [] + + def fake_llm(prompt: str) -> str: + calls.append(prompt) + return VALID_ANALYST_JSON + + result = parser.parse_with_retry("not json", fake_llm, max_retries=2) + self.assertIsInstance(result, AnalystReport) + self.assertEqual(len(calls), 1) # one retry was needed + + def test_retry_exhausted_raises(self): + """All retries return garbage → raises.""" + parser = StructuredOutputParser(TraderDecision) + + def always_bad(prompt: str) -> str: + return "still not valid" + + with self.assertRaises(Exception): + parser.parse_with_retry("bad", always_bad, max_retries=2) + + def test_retry_second_attempt_succeeds(self): + """First retry still bad, second retry returns valid JSON.""" + parser = StructuredOutputParser(RiskAssessment) + attempt = {"n": 0} + + def llm(prompt: str) -> str: + attempt["n"] += 1 + if attempt["n"] < 2: + return "still broken" + return VALID_RISK_JSON + + result = parser.parse_with_retry("garbage", llm, max_retries=2) + self.assertIsInstance(result, RiskAssessment) + self.assertEqual(attempt["n"], 2) + + +class TestValidateAgentOutput(unittest.TestCase): + """Test the convenience wrapper used by agent nodes.""" + + def test_valid_output_returns_model_and_fields(self): + model, fields = validate_agent_output(VALID_PORTFOLIO_JSON, PortfolioDecision) + self.assertIsInstance(model, PortfolioDecision) + self.assertEqual(fields["rating"], "Buy") + self.assertIn("confidence", fields) + + def test_invalid_output_without_llm_returns_none(self): + model, fields = validate_agent_output("garbage", AnalystReport, llm=None) + self.assertIsNone(model) + self.assertEqual(fields, {}) + + +class TestExtractFields(unittest.TestCase): + """Test structured field extraction from validated models.""" + + def test_analyst_extract(self): + m = AnalystReport.model_validate_json(VALID_ANALYST_JSON) + f = extract_fields(m) + self.assertIn("confidence", f) + self.assertIn("key_points", f) + self.assertNotIn("summary", f) # text fields excluded + + def test_trader_extract_enum_to_str(self): + m = TraderDecision.model_validate_json(VALID_TRADER_JSON) + f = extract_fields(m) + self.assertEqual(f["action"], "Buy") # enum → string + + def test_none_values_omitted(self): + m = TraderDecision(action="Buy", reasoning="bullish trend", confidence=0.5) + f = extract_fields(m) + self.assertNotIn("price_target", f) + self.assertNotIn("stop_loss", f) diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 1f03642c..5947cde6 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -17,6 +17,9 @@ from .risk_mgmt.neutral_debator import create_neutral_debator from .managers.research_manager import create_research_manager from .managers.portfolio_manager import create_portfolio_manager +from .output_parser import StructuredOutputParser, validate_agent_output +from .schemas import extract_fields + from .trader.trader import create_trader __all__ = [ @@ -37,4 +40,7 @@ __all__ = [ "create_conservative_debator", "create_social_media_analyst", "create_trader", + "StructuredOutputParser", + "validate_agent_output", + "extract_fields", ] diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 6aa49cf3..b2021e2d 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -8,6 +8,8 @@ from tradingagents.agents.utils.agent_utils import ( get_insider_transactions, get_language_instruction, ) +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import AnalystReport from tradingagents.dataflows.config import get_config @@ -60,6 +62,9 @@ def create_fundamentals_analyst(llm): if len(result.tool_calls) == 0: report = result.content + model, _ = validate_agent_output(report, AnalystReport, llm) + if model: + report = model.summary or report return { "messages": [result], diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index fef8f751..d1301d90 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -5,6 +5,8 @@ from tradingagents.agents.utils.agent_utils import ( get_language_instruction, get_stock_data, ) +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import AnalystReport from tradingagents.dataflows.config import get_config @@ -79,6 +81,9 @@ Volume-Based Indicators: if len(result.tool_calls) == 0: report = result.content + model, _ = validate_agent_output(report, AnalystReport, llm) + if model: + report = model.summary or report return { "messages": [result], diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index e0fe93c5..95f2df19 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -5,6 +5,8 @@ from tradingagents.agents.utils.agent_utils import ( get_language_instruction, get_news, ) +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import AnalystReport from tradingagents.dataflows.config import get_config @@ -53,6 +55,9 @@ def create_news_analyst(llm): if len(result.tool_calls) == 0: report = result.content + model, _ = validate_agent_output(report, AnalystReport, llm) + if model: + report = model.summary or report return { "messages": [result], diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 34a53c46..d231aa36 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,5 +1,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import AnalystReport from tradingagents.dataflows.config import get_config @@ -48,6 +50,9 @@ def create_social_media_analyst(llm): if len(result.tool_calls) == 0: report = result.content + model, _ = validate_agent_output(report, AnalystReport, llm) + if model: + report = model.summary or report return { "messages": [result], diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py index 6c69ae9f..6eb6e4e7 100644 --- a/tradingagents/agents/managers/portfolio_manager.py +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -1,4 +1,6 @@ from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import PortfolioDecision def create_portfolio_manager(llm, memory): @@ -56,6 +58,8 @@ Be decisive and ground every conclusion in specific evidence from the analysts.{ response = llm.invoke(prompt) + model, _ = validate_agent_output(response.content, PortfolioDecision, llm) + new_risk_debate_state = { "judge_decision": response.content, "history": risk_debate_state["history"], diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 5b4b4fdc..3dacfbf7 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -1,5 +1,7 @@ from tradingagents.agents.utils.agent_utils import build_instrument_context +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import TraderDecision def create_research_manager(llm, memory): @@ -41,6 +43,8 @@ Debate History: {history}""" response = llm.invoke(prompt) + model, _ = validate_agent_output(response.content, TraderDecision, llm) + new_investment_debate_state = { "judge_decision": response.content, "history": investment_debate_state.get("history", ""), diff --git a/tradingagents/agents/output_parser.py b/tradingagents/agents/output_parser.py new file mode 100644 index 00000000..77adda09 --- /dev/null +++ b/tradingagents/agents/output_parser.py @@ -0,0 +1,170 @@ +"""Output parser that validates LLM responses against Pydantic schemas.""" + +import json +import logging +import re +from collections.abc import Callable +from typing import Any, TypeVar + +from langchain_core.output_parsers import PydanticOutputParser +from pydantic import BaseModel, ValidationError + +T = TypeVar("T", bound=BaseModel) + +logger = logging.getLogger(__name__) + +_RETRY_PROMPT = ( + "Your previous response could not be parsed. Error:\n{error}\n\n" + "Please respond ONLY with valid JSON matching this schema:\n{instructions}\n\n" + "Previous (invalid) response:\n{previous}" +) + +MAX_RETRIES = 2 + + +class StructuredOutputParser: + """Validates LLM text output against a Pydantic model. + + Usage: + parser = StructuredOutputParser(AnalystReport) + instructions = parser.get_format_instructions() # inject into prompt + result = parser.parse(llm_response_text) # returns AnalystReport or raises + + With retry: + result = parser.parse_with_retry(llm_response_text, llm_caller) + """ + + def __init__(self, schema: type[T]) -> None: + self.schema = schema + self._langchain_parser = PydanticOutputParser(pydantic_object=schema) + + def get_format_instructions(self) -> str: + """Return formatting instructions to embed in the LLM prompt.""" + return self._langchain_parser.get_format_instructions() + + def parse(self, text: str) -> T: + """Parse LLM text into the Pydantic model. + + Tries JSON extraction first, then falls back to langchain parser. + + Raises: + ValidationError: If the output doesn't match the schema. + """ + # Try to extract JSON from markdown code fences or raw JSON + json_str = self._extract_json(text) + if json_str is not None: + try: + data = json.loads(json_str) + return self.schema.model_validate(data) + except (json.JSONDecodeError, ValidationError): + pass + + # Fallback: let langchain parser try + try: + return self._langchain_parser.parse(text) + except Exception as e: + # Re-raise as ValidationError for consistent handling + raise ValidationError.from_exception_data( + title=self.schema.__name__, + line_errors=[ + { + "type": "value_error", + "loc": (), + "msg": f"Failed to parse LLM output: {e}", + "input": text[:500], + "ctx": {"error": str(e)}, + } + ], + ) from e + + def parse_with_retry( + self, + text: str, + llm_caller: Callable[[str], str], + max_retries: int = MAX_RETRIES, + ) -> T: + """Parse with automatic retry on validation failure. + + On failure, sends the error and format instructions back to the LLM + via *llm_caller* (a callable that accepts a prompt string and returns + the LLM's text response). + + Args: + text: Initial LLM response text to parse. + llm_caller: ``fn(prompt) -> response_text`` used for retries. + max_retries: Maximum number of retry attempts (default 2). + + Returns: + Validated Pydantic model instance. + + Raises: + ValidationError: If all retries are exhausted. + """ + last_error: Exception | None = None + current_text = text + + for attempt in range(1 + max_retries): + try: + return self.parse(current_text) + except (ValidationError, Exception) as exc: + last_error = exc + if attempt < max_retries: + logger.warning( + "Validation failed for %s (attempt %d/%d): %s", + self.schema.__name__, + attempt + 1, + 1 + max_retries, + exc, + ) + retry_prompt = _RETRY_PROMPT.format( + error=str(exc), + instructions=self.get_format_instructions(), + previous=current_text[:1000], + ) + current_text = llm_caller(retry_prompt) + + # All retries exhausted — raise the last error + raise last_error # type: ignore[misc] + + @staticmethod + def _extract_json(text: str) -> str | None: + """Extract JSON from markdown code fences or find raw JSON object.""" + # Match ```json ... ``` or ``` ... ``` + match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL) + if match: + return match.group(1).strip() + + # Try to find a raw JSON object (non-greedy to avoid spanning multiple blocks) + match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text) + if match: + return match.group(0) + + return None + + +def validate_agent_output( + text: str, + schema: type[T], + llm: Any | None = None, +) -> tuple[T | None, dict]: + """Validate agent output against a schema, with optional LLM retry. + + Returns (model_instance, extracted_fields) on success, + or (None, {}) on failure (graceful degradation). + """ + from tradingagents.agents.schemas import extract_fields + + parser = StructuredOutputParser(schema) + + def _llm_caller(prompt: str) -> str: + return llm.invoke(prompt).content + + try: + if llm is not None: + model = parser.parse_with_retry(text, _llm_caller) + else: + model = parser.parse(text) + return model, extract_fields(model) + except Exception: + logger.warning("Schema validation failed for %s, passing raw text through", schema.__name__) + return None, {} diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index a44212dc..e41a313c 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -1,3 +1,5 @@ +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import RiskAssessment def create_bear_researcher(llm, memory): @@ -43,6 +45,8 @@ Use this information to deliver a compelling bear argument, refute the bull's cl response = llm.invoke(prompt) + model, _ = validate_agent_output(response.content, RiskAssessment, llm) + argument = f"Bear Analyst: {response.content}" new_investment_debate_state = { diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index d23d4d76..f04c94a9 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -1,3 +1,5 @@ +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import RiskAssessment def create_bull_researcher(llm, memory): @@ -41,6 +43,8 @@ Use this information to deliver a compelling bull argument, refute the bear's co response = llm.invoke(prompt) + model, _ = validate_agent_output(response.content, RiskAssessment, llm) + argument = f"Bull Analyst: {response.content}" new_investment_debate_state = { diff --git a/tradingagents/agents/risk_mgmt/aggressive_debator.py b/tradingagents/agents/risk_mgmt/aggressive_debator.py index 2dab1152..2ffa96c5 100644 --- a/tradingagents/agents/risk_mgmt/aggressive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggressive_debator.py @@ -1,3 +1,5 @@ +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import RiskAssessment def create_aggressive_debator(llm): @@ -32,6 +34,8 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes response = llm.invoke(prompt) + model, _ = validate_agent_output(response.content, RiskAssessment, llm) + argument = f"Aggressive Analyst: {response.content}" new_risk_debate_state = { diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index 99a8315e..91ffc2c9 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -1,3 +1,5 @@ +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import RiskAssessment def create_conservative_debator(llm): @@ -32,6 +34,8 @@ Engage by questioning their optimism and emphasizing the potential downsides the response = llm.invoke(prompt) + model, _ = validate_agent_output(response.content, RiskAssessment, llm) + argument = f"Conservative Analyst: {response.content}" new_risk_debate_state = { diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index e99ff0af..3c7f8f89 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -1,3 +1,5 @@ +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import RiskAssessment def create_neutral_debator(llm): @@ -32,6 +34,8 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the response = llm.invoke(prompt) + model, _ = validate_agent_output(response.content, RiskAssessment, llm) + argument = f"Neutral Analyst: {response.content}" new_risk_debate_state = { diff --git a/tradingagents/agents/schemas.py b/tradingagents/agents/schemas.py new file mode 100644 index 00000000..4e723def --- /dev/null +++ b/tradingagents/agents/schemas.py @@ -0,0 +1,90 @@ +"""Pydantic models for structured agent outputs.""" + +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseModel, Field + + +class ActionSignal(str, Enum): + BUY = "Buy" + SELL = "Sell" + HOLD = "Hold" + + +class PortfolioRating(str, Enum): + BUY = "Buy" + OVERWEIGHT = "Overweight" + HOLD = "Hold" + UNDERWEIGHT = "Underweight" + SELL = "Sell" + + +class AnalystReport(BaseModel): + """Structured output from any analyst (market, news, fundamentals, social media).""" + + summary: str = Field(description="Concise summary of key findings") + detailed_analysis: str = Field(description="Full analysis with supporting evidence") + key_points: list[str] = Field(description="Bullet list of actionable insights") + confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1") + + +class TraderDecision(BaseModel): + """Structured output from the trader agent.""" + + action: ActionSignal = Field(description="Trading action: Buy, Sell, or Hold") + reasoning: str = Field(description="Rationale for the decision") + confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1") + price_target: Optional[float] = Field(default=None, description="Target price if applicable") + stop_loss: Optional[float] = Field(default=None, description="Stop-loss price if applicable") + + +class RiskAssessment(BaseModel): + """Structured output from a risk debater (aggressive, conservative, neutral).""" + + stance: str = Field(description="The analyst's stance on the trade") + argument: str = Field(description="Core argument with supporting evidence") + risk_factors: list[str] = Field(description="Key risk factors identified") + confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1") + + +class PortfolioDecision(BaseModel): + """Structured output from the portfolio manager.""" + + rating: PortfolioRating = Field(description="Buy / Overweight / Hold / Underweight / Sell") + executive_summary: str = Field(description="Concise action plan") + investment_thesis: str = Field(description="Detailed reasoning for the decision") + confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1") + price_target: Optional[float] = Field(default=None, description="Target price if applicable") + time_horizon: Optional[str] = Field(default=None, description="Recommended holding period") + + +# Key fields to extract from each model type (only non-text, structured data) +_EXTRACT_KEYS: dict[type[BaseModel], list[str]] = { + AnalystReport: ["confidence", "key_points"], + TraderDecision: ["action", "confidence", "price_target", "stop_loss"], + RiskAssessment: ["stance", "confidence", "risk_factors"], + PortfolioDecision: ["rating", "confidence", "price_target", "time_horizon"], +} + + +def extract_fields( + model: Union[AnalystReport, TraderDecision, RiskAssessment, PortfolioDecision], +) -> dict: + """Extract key structured fields from a validated Pydantic model. + + Returns a flat dict containing only the actionable structured fields + (rating, action, price targets, confidence, etc.) with None values omitted. + Enum values are converted to their string representation. + """ + keys = _EXTRACT_KEYS.get(type(model), list(model.model_fields.keys())) + result: dict = {} + for key in keys: + val = getattr(model, key) + if val is None: + continue + # Convert enums to their string value + if isinstance(val, Enum): + val = val.value + result[key] = val + return result diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 07e9f262..5ab9a1b6 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -1,6 +1,8 @@ import functools from tradingagents.agents.utils.agent_utils import build_instrument_context +from tradingagents.agents.output_parser import validate_agent_output +from tradingagents.agents.schemas import TraderDecision def create_trader(llm, memory): @@ -38,6 +40,8 @@ def create_trader(llm, memory): result = llm.invoke(messages) + model, _ = validate_agent_output(result.content, TraderDecision, llm) + return { "messages": [result], "trader_investment_plan": result.content,