Merge cc721a87be into fa4d01c23a
This commit is contained in:
commit
64480eea27
|
|
@ -0,0 +1,144 @@
|
|||
"""Tests for structured output parsing with retry on malformed output."""
|
||||
|
||||
import unittest
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from tradingagents.agents.output_parser import StructuredOutputParser, validate_agent_output
|
||||
from tradingagents.agents.schemas import (
|
||||
AnalystReport,
|
||||
PortfolioDecision,
|
||||
RiskAssessment,
|
||||
TraderDecision,
|
||||
extract_fields,
|
||||
)
|
||||
|
||||
|
||||
# Canonical well-formed JSON payloads, one per schema, shared across tests.
# Each is built from adjacent string literals (implicit concatenation); the
# payload text itself must not change, as tests assert on specific values.
VALID_ANALYST_JSON = (
    '{"summary": "Stock is up", "detailed_analysis": "Strong earnings beat",'
    ' "key_points": ["Revenue up 20%"], "confidence": 0.85}'
)

VALID_TRADER_JSON = (
    '{"action": "Buy", "reasoning": "Bullish trend", "confidence": 0.9,'
    ' "price_target": 150.0, "stop_loss": 130.0}'
)

VALID_RISK_JSON = (
    '{"stance": "Cautious", "argument": "Volatility is high",'
    ' "risk_factors": ["Market downturn"], "confidence": 0.6}'
)

VALID_PORTFOLIO_JSON = (
    '{"rating": "Buy", "executive_summary": "Strong buy",'
    ' "investment_thesis": "Solid fundamentals", "confidence": 0.8,'
    ' "price_target": 200.0, "time_horizon": "6 months"}'
)
|
||||
|
||||
|
||||
class TestStructuredOutputParser(unittest.TestCase):
    """Cover parsing, retry behaviour, and failure modes for every schema type."""

    def test_parse_valid_json(self):
        report = StructuredOutputParser(AnalystReport).parse(VALID_ANALYST_JSON)
        self.assertIsInstance(report, AnalystReport)
        self.assertEqual(report.confidence, 0.85)

    def test_parse_json_in_code_fence(self):
        fenced = f"```json\n{VALID_TRADER_JSON}\n```"
        decision = StructuredOutputParser(TraderDecision).parse(fenced)
        self.assertEqual(decision.action.value, "Buy")

    def test_parse_malformed_raises(self):
        with self.assertRaises((ValidationError, Exception)):
            StructuredOutputParser(AnalystReport).parse("This is not JSON at all")

    def test_parse_missing_required_field_raises(self):
        # 'action' and 'reasoning' are required by TraderDecision but absent here.
        with self.assertRaises((ValidationError, Exception)):
            StructuredOutputParser(TraderDecision).parse('{"confidence": 0.5}')

    def test_parse_confidence_out_of_range_raises(self):
        payload = '{"summary": "x", "detailed_analysis": "x", "key_points": [], "confidence": 1.5}'
        with self.assertRaises((ValidationError, Exception)):
            StructuredOutputParser(AnalystReport).parse(payload)

    def test_retry_recovers_from_malformed_output(self):
        """Malformed first response → retry → valid JSON → success."""
        prompts_seen: list[str] = []

        def fake_llm(prompt: str) -> str:
            prompts_seen.append(prompt)
            return VALID_ANALYST_JSON

        parsed = StructuredOutputParser(AnalystReport).parse_with_retry(
            "not json", fake_llm, max_retries=2
        )
        self.assertIsInstance(parsed, AnalystReport)
        self.assertEqual(len(prompts_seen), 1)  # exactly one retry round-trip

    def test_retry_exhausted_raises(self):
        """All retries return garbage → raises."""

        def always_bad(prompt: str) -> str:
            return "still not valid"

        with self.assertRaises(Exception):
            StructuredOutputParser(TraderDecision).parse_with_retry(
                "bad", always_bad, max_retries=2
            )

    def test_retry_second_attempt_succeeds(self):
        """First retry still bad, second retry returns valid JSON."""
        call_count = [0]

        def llm(prompt: str) -> str:
            call_count[0] += 1
            return "still broken" if call_count[0] < 2 else VALID_RISK_JSON

        parsed = StructuredOutputParser(RiskAssessment).parse_with_retry(
            "garbage", llm, max_retries=2
        )
        self.assertIsInstance(parsed, RiskAssessment)
        self.assertEqual(call_count[0], 2)
|
||||
|
||||
|
||||
class TestValidateAgentOutput(unittest.TestCase):
    """Exercise the convenience wrapper that agent nodes call."""

    def test_valid_output_returns_model_and_fields(self):
        model, fields = validate_agent_output(VALID_PORTFOLIO_JSON, PortfolioDecision)
        self.assertIsInstance(model, PortfolioDecision)
        self.assertIn("confidence", fields)
        self.assertEqual(fields["rating"], "Buy")

    def test_invalid_output_without_llm_returns_none(self):
        # No LLM available → no retry possible → graceful (None, {}) degradation.
        model, fields = validate_agent_output("garbage", AnalystReport, llm=None)
        self.assertEqual(fields, {})
        self.assertIsNone(model)
|
||||
|
||||
|
||||
class TestExtractFields(unittest.TestCase):
    """Check structured field extraction from validated models."""

    def test_analyst_extract(self):
        extracted = extract_fields(AnalystReport.model_validate_json(VALID_ANALYST_JSON))
        self.assertIn("confidence", extracted)
        self.assertIn("key_points", extracted)
        self.assertNotIn("summary", extracted)  # free-text fields are excluded

    def test_trader_extract_enum_to_str(self):
        extracted = extract_fields(TraderDecision.model_validate_json(VALID_TRADER_JSON))
        self.assertEqual(extracted["action"], "Buy")  # enum flattened to plain string

    def test_none_values_omitted(self):
        decision = TraderDecision(action="Buy", reasoning="bullish trend", confidence=0.5)
        extracted = extract_fields(decision)
        self.assertNotIn("stop_loss", extracted)
        self.assertNotIn("price_target", extracted)
|
||||
|
|
@ -17,6 +17,9 @@ from .risk_mgmt.neutral_debator import create_neutral_debator
|
|||
from .managers.research_manager import create_research_manager
|
||||
from .managers.portfolio_manager import create_portfolio_manager
|
||||
|
||||
from .output_parser import StructuredOutputParser, validate_agent_output
|
||||
from .schemas import extract_fields
|
||||
|
||||
from .trader.trader import create_trader
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -37,4 +40,7 @@ __all__ = [
|
|||
"create_conservative_debator",
|
||||
"create_social_media_analyst",
|
||||
"create_trader",
|
||||
"StructuredOutputParser",
|
||||
"validate_agent_output",
|
||||
"extract_fields",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_insider_transactions,
|
||||
get_language_instruction,
|
||||
)
|
||||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import AnalystReport
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -60,6 +62,9 @@ def create_fundamentals_analyst(llm):
|
|||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
model, _ = validate_agent_output(report, AnalystReport, llm)
|
||||
if model:
|
||||
report = model.summary or report
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_language_instruction,
|
||||
get_stock_data,
|
||||
)
|
||||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import AnalystReport
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -79,6 +81,9 @@ Volume-Based Indicators:
|
|||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
model, _ = validate_agent_output(report, AnalystReport, llm)
|
||||
if model:
|
||||
report = model.summary or report
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_language_instruction,
|
||||
get_news,
|
||||
)
|
||||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import AnalystReport
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -53,6 +55,9 @@ def create_news_analyst(llm):
|
|||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
model, _ = validate_agent_output(report, AnalystReport, llm)
|
||||
if model:
|
||||
report = model.summary or report
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
|
||||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import AnalystReport
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -48,6 +50,9 @@ def create_social_media_analyst(llm):
|
|||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
model, _ = validate_agent_output(report, AnalystReport, llm)
|
||||
if model:
|
||||
report = model.summary or report
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
|
||||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import PortfolioDecision
|
||||
|
||||
|
||||
def create_portfolio_manager(llm, memory):
|
||||
|
|
@ -56,6 +58,8 @@ Be decisive and ground every conclusion in specific evidence from the analysts.{
|
|||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
model, _ = validate_agent_output(response.content, PortfolioDecision, llm)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": risk_debate_state["history"],
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import TraderDecision
|
||||
|
||||
|
||||
def create_research_manager(llm, memory):
|
||||
|
|
@ -41,6 +43,8 @@ Debate History:
|
|||
{history}"""
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
model, _ = validate_agent_output(response.content, TraderDecision, llm)
|
||||
|
||||
new_investment_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": investment_debate_state.get("history", ""),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,170 @@
|
|||
"""Output parser that validates LLM responses against Pydantic schemas."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Prompt template sent back to the LLM when its previous response failed
# schema validation. Placeholders: {error} (validation message),
# {instructions} (schema format instructions), {previous} (truncated
# invalid response).
_RETRY_PROMPT = (
    "Your previous response could not be parsed. Error:\n{error}\n\n"
    "Please respond ONLY with valid JSON matching this schema:\n{instructions}\n\n"
    "Previous (invalid) response:\n{previous}"
)

# Default number of correction round-trips for parse_with_retry.
MAX_RETRIES = 2
|
||||
|
||||
|
||||
class StructuredOutputParser:
    """Validates LLM text output against a Pydantic model.

    Usage:
        parser = StructuredOutputParser(AnalystReport)
        instructions = parser.get_format_instructions()  # inject into prompt
        result = parser.parse(llm_response_text)  # returns AnalystReport or raises

    With retry:
        result = parser.parse_with_retry(llm_response_text, llm_caller)
    """

    def __init__(self, schema: type[T]) -> None:
        self.schema = schema
        # Kept as a fallback for responses whose JSON we cannot extract cleanly.
        self._langchain_parser = PydanticOutputParser(pydantic_object=schema)

    def get_format_instructions(self) -> str:
        """Return formatting instructions to embed in the LLM prompt."""
        return self._langchain_parser.get_format_instructions()

    def parse(self, text: str) -> T:
        """Parse LLM text into the Pydantic model.

        Tries JSON extraction first, then falls back to the langchain parser.

        Args:
            text: Raw LLM response text.

        Returns:
            A validated instance of the configured schema.

        Raises:
            ValidationError: If the output doesn't match the schema.
        """
        # Try to extract JSON from markdown code fences or raw JSON.
        json_str = self._extract_json(text)
        if json_str is not None:
            try:
                data = json.loads(json_str)
                return self.schema.model_validate(data)
            except (json.JSONDecodeError, ValidationError):
                pass  # fall through to the more forgiving langchain parser

        # Fallback: let the langchain parser try the raw text.
        try:
            return self._langchain_parser.parse(text)
        except Exception as e:
            # Normalize to ValidationError so callers handle one exception type.
            raise ValidationError.from_exception_data(
                title=self.schema.__name__,
                line_errors=[
                    {
                        "type": "value_error",
                        "loc": (),
                        "msg": f"Failed to parse LLM output: {e}",
                        "input": text[:500],  # truncate to keep errors readable
                        "ctx": {"error": str(e)},
                    }
                ],
            ) from e

    def parse_with_retry(
        self,
        text: str,
        llm_caller: Callable[[str], str],
        max_retries: int = MAX_RETRIES,
    ) -> T:
        """Parse with automatic retry on validation failure.

        On failure, sends the error and format instructions back to the LLM
        via *llm_caller* (a callable that accepts a prompt string and returns
        the LLM's text response).

        Args:
            text: Initial LLM response text to parse.
            llm_caller: ``fn(prompt) -> response_text`` used for retries.
            max_retries: Maximum number of retry attempts (default 2).

        Returns:
            Validated Pydantic model instance.

        Raises:
            ValueError: If *max_retries* is negative.
            ValidationError: If all retries are exhausted.
        """
        # Guard: a negative value would skip the loop entirely and reach
        # `raise last_error` with last_error still None (a bare TypeError).
        if max_retries < 0:
            raise ValueError("max_retries must be non-negative")

        last_error: Exception | None = None
        current_text = text

        # One initial attempt plus max_retries correction round-trips.
        for attempt in range(1 + max_retries):
            try:
                return self.parse(current_text)
            except Exception as exc:  # parse() normalizes failures to ValidationError
                last_error = exc
                if attempt < max_retries:
                    logger.warning(
                        "Validation failed for %s (attempt %d/%d): %s",
                        self.schema.__name__,
                        attempt + 1,
                        1 + max_retries,
                        exc,
                    )
                    retry_prompt = _RETRY_PROMPT.format(
                        error=str(exc),
                        instructions=self.get_format_instructions(),
                        previous=current_text[:1000],  # cap prompt size
                    )
                    current_text = llm_caller(retry_prompt)

        # All retries exhausted — raise the last error.
        raise last_error  # type: ignore[misc]

    @staticmethod
    def _extract_json(text: str) -> str | None:
        """Extract JSON from markdown code fences or find a raw JSON object."""
        # Match ```json ... ``` or ``` ... ```.
        match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Raw JSON object (non-greedy to avoid spanning multiple blocks).
        # NOTE(review): this pattern matches at most one level of nested braces;
        # deeper nesting falls through to the langchain parser above.
        match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text)
        if match:
            return match.group(0)

        return None
|
||||
|
||||
|
||||
def validate_agent_output(
    text: str,
    schema: type[T],
    llm: Any | None = None,
) -> tuple[T | None, dict]:
    """Validate agent output against a schema, with optional LLM retry.

    Args:
        text: Raw agent/LLM response text.
        schema: Pydantic model class the output must conform to.
        llm: Optional LLM object exposing ``invoke(prompt).content``; when
            provided, malformed output triggers correction retries.

    Returns:
        ``(model_instance, extracted_fields)`` on success, or ``(None, {})``
        on failure (graceful degradation — callers keep their raw text).
    """
    # Imported lazily to avoid a circular import between this module and schemas.
    from tradingagents.agents.schemas import extract_fields

    parser = StructuredOutputParser(schema)

    try:
        if llm is not None:
            # Adapter built only when needed: parse_with_retry expects a plain
            # str -> str callable.
            def _llm_caller(prompt: str) -> str:
                return llm.invoke(prompt).content

            model = parser.parse_with_retry(text, _llm_caller)
        else:
            model = parser.parse(text)
        return model, extract_fields(model)
    except Exception as exc:
        # Include the actual failure reason; the original message dropped it,
        # making malformed-output issues hard to diagnose from logs.
        logger.warning(
            "Schema validation failed for %s, passing raw text through: %s",
            schema.__name__,
            exc,
        )
        return None, {}
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import RiskAssessment
|
||||
|
||||
|
||||
def create_bear_researcher(llm, memory):
|
||||
|
|
@ -43,6 +45,8 @@ Use this information to deliver a compelling bear argument, refute the bull's cl
|
|||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
|
||||
|
||||
argument = f"Bear Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import RiskAssessment
|
||||
|
||||
|
||||
def create_bull_researcher(llm, memory):
|
||||
|
|
@ -41,6 +43,8 @@ Use this information to deliver a compelling bull argument, refute the bear's co
|
|||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
|
||||
|
||||
argument = f"Bull Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import RiskAssessment
|
||||
|
||||
|
||||
def create_aggressive_debator(llm):
|
||||
|
|
@ -32,6 +34,8 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes
|
|||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
|
||||
|
||||
argument = f"Aggressive Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import RiskAssessment
|
||||
|
||||
|
||||
def create_conservative_debator(llm):
|
||||
|
|
@ -32,6 +34,8 @@ Engage by questioning their optimism and emphasizing the potential downsides the
|
|||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
|
||||
|
||||
argument = f"Conservative Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import RiskAssessment
|
||||
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
|
|
@ -32,6 +34,8 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the
|
|||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
model, _ = validate_agent_output(response.content, RiskAssessment, llm)
|
||||
|
||||
argument = f"Neutral Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
"""Pydantic models for structured agent outputs."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ActionSignal(str, Enum):
    """Closed set of trading actions; the str mix-in makes members plain strings."""

    BUY = "Buy"
    SELL = "Sell"
    HOLD = "Hold"
|
||||
|
||||
|
||||
class PortfolioRating(str, Enum):
    """Five-level portfolio rating scale, ordered from most to least bullish."""

    BUY = "Buy"
    OVERWEIGHT = "Overweight"
    HOLD = "Hold"
    UNDERWEIGHT = "Underweight"
    SELL = "Sell"
|
||||
|
||||
|
||||
class AnalystReport(BaseModel):
    """Structured output from any analyst (market, news, fundamentals, social media)."""

    # summary and detailed_analysis are free text; extract_fields keeps only
    # confidence and key_points (see _EXTRACT_KEYS).
    summary: str = Field(description="Concise summary of key findings")
    detailed_analysis: str = Field(description="Full analysis with supporting evidence")
    key_points: list[str] = Field(description="Bullet list of actionable insights")
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1")
|
||||
|
||||
|
||||
class TraderDecision(BaseModel):
    """Structured output from the trader agent."""

    action: ActionSignal = Field(description="Trading action: Buy, Sell, or Hold")
    reasoning: str = Field(description="Rationale for the decision")
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1")
    # Optional price levels default to None; extract_fields omits unset values.
    price_target: Optional[float] = Field(default=None, description="Target price if applicable")
    stop_loss: Optional[float] = Field(default=None, description="Stop-loss price if applicable")
|
||||
|
||||
|
||||
class RiskAssessment(BaseModel):
    """Structured output from a risk debater (aggressive, conservative, neutral)."""

    stance: str = Field(description="The analyst's stance on the trade")
    argument: str = Field(description="Core argument with supporting evidence")
    risk_factors: list[str] = Field(description="Key risk factors identified")
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1")
|
||||
|
||||
|
||||
class PortfolioDecision(BaseModel):
    """Structured output from the portfolio manager."""

    rating: PortfolioRating = Field(description="Buy / Overweight / Hold / Underweight / Sell")
    executive_summary: str = Field(description="Concise action plan")
    investment_thesis: str = Field(description="Detailed reasoning for the decision")
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence level 0-1")
    # Optional fields default to None; extract_fields omits unset values.
    price_target: Optional[float] = Field(default=None, description="Target price if applicable")
    time_horizon: Optional[str] = Field(default=None, description="Recommended holding period")
|
||||
|
||||
|
||||
# Key fields to extract from each model type (only non-text, structured data)
_EXTRACT_KEYS: dict[type[BaseModel], list[str]] = {
    AnalystReport: ["confidence", "key_points"],
    TraderDecision: ["action", "confidence", "price_target", "stop_loss"],
    RiskAssessment: ["stance", "confidence", "risk_factors"],
    PortfolioDecision: ["rating", "confidence", "price_target", "time_horizon"],
}


def extract_fields(
    model: Union[AnalystReport, TraderDecision, RiskAssessment, PortfolioDecision],
) -> dict:
    """Extract key structured fields from a validated Pydantic model.

    Produces a flat dict holding only the actionable structured fields
    (rating, action, price targets, confidence, etc.). ``None`` values are
    dropped, and enum members are flattened to their string value. Unknown
    model types fall back to all declared fields.
    """
    wanted = _EXTRACT_KEYS.get(type(model), list(model.model_fields.keys()))
    return {
        name: (value.value if isinstance(value, Enum) else value)
        for name in wanted
        if (value := getattr(model, name)) is not None
    }
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
import functools
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.agents.output_parser import validate_agent_output
|
||||
from tradingagents.agents.schemas import TraderDecision
|
||||
|
||||
|
||||
def create_trader(llm, memory):
|
||||
|
|
@ -38,6 +40,8 @@ def create_trader(llm, memory):
|
|||
|
||||
result = llm.invoke(messages)
|
||||
|
||||
model, _ = validate_agent_output(result.content, TraderDecision, llm)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"trader_investment_plan": result.content,
|
||||
|
|
|
|||
Loading…
Reference in New Issue