This commit is contained in:
LaBoon 2026-04-14 20:52:33 -05:00 committed by GitHub
commit 9a381bbc25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 244 additions and 0 deletions

View File

@ -0,0 +1,49 @@
"""
Example: Using Pydantic validation in analyst agents.
This demonstrates how to add validation at agent boundaries
to catch errors early and provide clear feedback.
Issue #434: https://github.com/TauricResearch/TradingAgents/issues/434
"""
from tradingagents.agents.utils.pydantic_validation import (
AnalystReport,
safe_validate_agent_output,
)
def create_market_analyst_with_validation(llm):
"""
Enhanced market analyst with Pydantic validation.
Wraps the standard market analyst to validate outputs
and provide clear error messages when validation fails.
"""
from tradingagents.agents.analysts.market_analyst import create_market_analyst
# Get the original analyst
original_analyst = create_market_analyst(llm)
def validated_market_analyst_node(state):
# Run the original analyst
result = original_analyst(state)
# Validate the output
if isinstance(result, dict):
validated = safe_validate_agent_output(result, AnalystReport)
if validated.get('_validation_status') == 'invalid':
# Log validation error but continue with original output
print(f"⚠️ Validation warning: {validated.get('_validation_error')}")
return validated
return result
return validated_market_analyst_node
# Usage example:
# from tradingagents.agents.utils.pydantic_validation import create_market_analyst_with_validation
# validated_analyst = create_market_analyst_with_validation(llm)

View File

@ -29,6 +29,7 @@ dependencies = [
"stockstats>=0.6.5", "stockstats>=0.6.5",
"tqdm>=4.67.1", "tqdm>=4.67.1",
"typing-extensions>=4.14.0", "typing-extensions>=4.14.0",
"pydantic>=2.0.0",
"yfinance>=0.2.63", "yfinance>=0.2.63",
] ]

View File

@ -70,3 +70,19 @@ class AgentState(MessagesState):
RiskDebateState, "Current state of the debate on evaluating risk" RiskDebateState, "Current state of the debate on evaluating risk"
] ]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
# Import Pydantic validation helpers (Issue #434)
try:
from tradingagents.agents.utils.pydantic_validation import (
AnalystReport,
InvestDebateStateValidated,
RiskDebateStateValidated,
TradeDecision,
AgentInput,
validate_agent_output,
safe_validate_agent_output,
)
HAS_PYDANTIC_VALIDATION = True
except ImportError:
HAS_PYDANTIC_VALIDATION = False

View File

@ -0,0 +1,178 @@
"""
Pydantic validation models for TradingAgents.
Provides strict schema validation at agent boundaries to catch
validation errors early and provide clear error messages.
Issue #434: https://github.com/TauricResearch/TradingAgents/issues/434
"""
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field, field_validator
from datetime import date as DateType
class AnalystReport(BaseModel):
"""Validated analyst report output."""
report: str = Field(
...,
min_length=10,
description="Detailed analyst report with market insights"
)
indicators_used: List[str] = Field(
default_factory=list,
description="List of technical indicators used in the analysis"
)
has_trade_proposal: bool = Field(
default=False,
description="Whether the report contains a FINAL TRANSACTION PROPOSAL"
)
@field_validator('report')
@classmethod
def validate_report_quality(cls, v: str) -> str:
if len(v.strip()) < 10:
raise ValueError("Report must be at least 10 characters")
return v.strip()
class InvestDebateStateValidated(BaseModel):
"""Validated research debate state."""
bull_history: str = Field(default="", description="Bullish conversation history")
bear_history: str = Field(default="", description="Bearish conversation history")
history: str = Field(default="", description="Full conversation history")
current_response: str = Field(default="", description="Latest response")
judge_decision: str = Field(default="", description="Final judge decision")
count: int = Field(default=0, ge=0, description="Conversation length")
@field_validator('judge_decision')
@classmethod
def validate_judge_decision(cls, v: str) -> str:
if v and v.strip().upper() not in ['', 'BUY', 'SELL', 'HOLD']:
# Allow any text but warn about non-standard decisions
pass
return v
class RiskDebateStateValidated(BaseModel):
"""Validated risk management debate state."""
aggressive_history: str = Field(default="", description="Aggressive agent history")
conservative_history: str = Field(default="", description="Conservative agent history")
neutral_history: str = Field(default="", description="Neutral agent history")
history: str = Field(default="", description="Full conversation history")
latest_speaker: str = Field(default="", description="Last speaker")
current_aggressive_response: str = Field(default="")
current_conservative_response: str = Field(default="")
current_neutral_response: str = Field(default="")
judge_decision: str = Field(default="", description="Judge's decision")
count: int = Field(default=0, ge=0, description="Conversation length")
class TradeDecision(BaseModel):
"""Validated final trade decision."""
decision: str = Field(
...,
description="Trade decision: BUY, SELL, or HOLD"
)
confidence: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Confidence level (0-1)"
)
reasoning: str = Field(
default="",
description="Brief reasoning for the decision"
)
@field_validator('decision')
@classmethod
def validate_decision(cls, v: str) -> str:
v = v.strip().upper()
if v not in ['BUY', 'SELL', 'HOLD']:
raise ValueError(f"Decision must be BUY, SELL, or HOLD, got: {v}")
return v
class AgentInput(BaseModel):
"""Validated agent input state."""
company_of_interest: str = Field(
...,
min_length=1,
max_length=10,
description="Stock ticker symbol"
)
trade_date: str = Field(
...,
description="Trading date in YYYY-MM-DD format"
)
@field_validator('company_of_interest')
@classmethod
def validate_ticker(cls, v: str) -> str:
v = v.strip().upper()
if not v.isalpha():
raise ValueError(f"Ticker must be alphabetic, got: {v}")
return v
@field_validator('trade_date')
@classmethod
def validate_date(cls, v: str) -> str:
try:
DateType.fromisoformat(v)
except ValueError:
raise ValueError(f"Invalid date format: {v}. Expected YYYY-MM-DD")
return v
def validate_agent_output(
output: Dict[str, Any],
model_class: type[BaseModel]
) -> Dict[str, Any]:
"""
Validate agent output against a Pydantic model.
Args:
output: Raw agent output dictionary
model_class: Pydantic model class to validate against
Returns:
Validated dictionary
Raises:
ValueError: If validation fails
"""
try:
validated = model_class(**output)
return validated.model_dump()
except Exception as e:
raise ValueError(
f"Agent output validation failed for {model_class.__name__}: {e}"
)
def safe_validate_agent_output(
output: Dict[str, Any],
model_class: type[BaseModel]
) -> Dict[str, Any]:
"""
Safely validate agent output with fallback.
If validation fails, returns the original output with an error field.
Does not raise exceptions.
"""
try:
validated = model_class(**output)
result = validated.model_dump()
result['_validation_status'] = 'valid'
return result
except Exception as e:
result = dict(output) if isinstance(output, dict) else {'raw': str(output)}
result['_validation_status'] = 'invalid'
result['_validation_error'] = str(e)
return result