TradingAgents/tradingagents/agents/utils/pydantic_validation.py

179 lines
5.6 KiB
Python

"""
Pydantic validation models for TradingAgents.
Provides strict schema validation at agent boundaries to catch
validation errors early and provide clear error messages.
Issue #434: https://github.com/TauricResearch/TradingAgents/issues/434
"""
import logging
import re
from datetime import date as DateType
from typing import Optional, List, Dict, Any

from pydantic import BaseModel, Field, field_validator
class AnalystReport(BaseModel):
    """Validated analyst report output."""

    # Free-text body. Field(min_length=10) is enforced on the raw input before
    # the validator below runs; the validator re-checks after stripping.
    report: str = Field(
        ...,
        min_length=10,
        description="Detailed analyst report with market insights",
    )
    # Each instance gets its own fresh empty list by default.
    indicators_used: List[str] = Field(
        default_factory=list,
        description="List of technical indicators used in the analysis",
    )
    has_trade_proposal: bool = Field(
        False,
        description="Whether the report contains a FINAL TRANSACTION PROPOSAL",
    )

    @field_validator('report')
    @classmethod
    def validate_report_quality(cls, v: str) -> str:
        """Return the report stripped of surrounding whitespace.

        Raises:
            ValueError: if fewer than 10 characters remain after stripping.
        """
        trimmed = v.strip()
        if len(trimmed) >= 10:
            return trimmed
        raise ValueError("Report must be at least 10 characters")
class InvestDebateStateValidated(BaseModel):
    """Validated research debate state."""

    bull_history: str = Field(default="", description="Bullish conversation history")
    bear_history: str = Field(default="", description="Bearish conversation history")
    history: str = Field(default="", description="Full conversation history")
    current_response: str = Field(default="", description="Latest response")
    judge_decision: str = Field(default="", description="Final judge decision")
    # ge=0 rejects negative counts at validation time.
    count: int = Field(default=0, ge=0, description="Conversation length")

    @field_validator('judge_decision')
    @classmethod
    def validate_judge_decision(cls, v: str) -> str:
        """Accept any judge text, logging a warning when it is non-standard.

        The value is always returned unchanged. An empty or whitespace-only
        value, or one that is exactly BUY, SELL or HOLD (case-insensitive,
        ignoring surrounding whitespace), is accepted silently. Any other
        text is still accepted, but a warning is logged.
        """
        normalized = v.strip().upper()
        if normalized and normalized not in ('BUY', 'SELL', 'HOLD'):
            # Free-form decisions are allowed, so log instead of raising.
            # The text is truncated to keep long decisions from flooding the log.
            logging.getLogger(__name__).warning(
                "Non-standard judge_decision (expected BUY, SELL or HOLD): %r",
                v.strip()[:80],
            )
        return v
class RiskDebateStateValidated(BaseModel):
    """Validated risk management debate state."""

    # Per-stance transcripts, followed by the combined transcript.
    aggressive_history: str = Field("", description="Aggressive agent history")
    conservative_history: str = Field("", description="Conservative agent history")
    neutral_history: str = Field("", description="Neutral agent history")
    history: str = Field("", description="Full conversation history")
    latest_speaker: str = Field("", description="Last speaker")
    # Most recent message from each stance; these fields carry no schema description.
    current_aggressive_response: str = ""
    current_conservative_response: str = ""
    current_neutral_response: str = ""
    judge_decision: str = Field("", description="Judge's decision")
    # Non-negative number of exchanges so far (ge=0 is enforced).
    count: int = Field(0, ge=0, description="Conversation length")
class TradeDecision(BaseModel):
    """Validated final trade decision."""

    # Normalized to upper case by validate_decision below.
    decision: str = Field(..., description="Trade decision: BUY, SELL, or HOLD")
    # Bounded to the closed interval [0, 1].
    confidence: float = Field(
        0.5,
        ge=0.0,
        le=1.0,
        description="Confidence level (0-1)",
    )
    reasoning: str = Field("", description="Brief reasoning for the decision")

    @field_validator('decision')
    @classmethod
    def validate_decision(cls, v: str) -> str:
        """Normalize the decision to an upper-case BUY/SELL/HOLD token.

        Surrounding whitespace is stripped and case is ignored.

        Raises:
            ValueError: if the normalized value is not BUY, SELL or HOLD.
        """
        action = v.strip().upper()
        if action in ('BUY', 'SELL', 'HOLD'):
            return action
        raise ValueError(f"Decision must be BUY, SELL, or HOLD, got: {action}")
class AgentInput(BaseModel):
    """Validated agent input state."""

    # min_length/max_length apply to the raw value, before stripping.
    # NOTE(review): max_length=10 rejects longer exchange-suffixed symbols
    # such as 'RELIANCE.NS' (11 chars); confirm whether that is intended.
    company_of_interest: str = Field(
        ...,
        min_length=1,
        max_length=10,
        description="Stock ticker symbol",
    )
    trade_date: str = Field(
        ...,
        description="Trading date in YYYY-MM-DD format",
    )

    @field_validator('company_of_interest')
    @classmethod
    def validate_ticker(cls, v: str) -> str:
        """Strip and upper-case the ticker, then check its shape.

        Accepted: ASCII letters/digits, optionally joined by single '.', '-'
        or '=' separators, with an optional leading '^'. Examples: AAPL,
        BRK-B, BRK.B, 0700.HK, ^GSPC, GC=F. The previous ``isalpha()`` check
        rejected such symbols while accepting non-ASCII letters.

        Raises:
            ValueError: if the normalized ticker does not match that shape.
        """
        v = v.strip().upper()
        # The A-Z/0-9 classes are ASCII-only, unlike str.isalpha().
        if not re.fullmatch(r"\^?[A-Z0-9]+(?:[.\-=][A-Z0-9]+)*", v):
            raise ValueError(
                f"Invalid ticker symbol: {v!r}. Expected letters/digits, "
                "optionally joined by '.', '-' or '=' (e.g. AAPL, BRK-B, 0700.HK)"
            )
        return v

    @field_validator('trade_date')
    @classmethod
    def validate_date(cls, v: str) -> str:
        """Require a real calendar date spelled exactly as YYYY-MM-DD.

        Raises:
            ValueError: if ``v`` is not a valid date or is not in that format.
        """
        try:
            parsed = DateType.fromisoformat(v)
        except ValueError as err:
            raise ValueError(f"Invalid date format: {v}. Expected YYYY-MM-DD") from err
        # On Python 3.11+ fromisoformat also accepts forms like '20240115' and
        # '2024-W03-1'; an exact round-trip pins the YYYY-MM-DD spelling.
        if parsed.isoformat() != v:
            raise ValueError(f"Invalid date format: {v}. Expected YYYY-MM-DD")
        return v
def validate_agent_output(
    output: Dict[str, Any],
    model_class: type[BaseModel]
) -> Dict[str, Any]:
    """
    Validate agent output against a Pydantic model.

    Args:
        output: Raw agent output dictionary; its items are passed to
            ``model_class`` as keyword arguments.
        model_class: Pydantic model class to validate against.

    Returns:
        The validated model serialized with ``model_dump()``.

    Raises:
        ValueError: If construction or serialization fails for any reason,
            including a non-mapping ``output``. The original exception is
            chained as ``__cause__``.
    """
    try:
        validated = model_class(**output)
        return validated.model_dump()
    except Exception as e:
        # Broad catch on purpose: callers rely on a single ValueError contract.
        # Chaining keeps the underlying ValidationError/TypeError for debugging.
        raise ValueError(
            f"Agent output validation failed for {model_class.__name__}: {e}"
        ) from e
def safe_validate_agent_output(
    output: Dict[str, Any],
    model_class: type[BaseModel]
) -> Dict[str, Any]:
    """
    Validate agent output, reporting failure in-band instead of raising.

    On success, returns ``model_class(**output).model_dump()`` with
    ``_validation_status`` set to ``'valid'``. If validation or dumping raises
    any ``Exception``, returns a shallow copy of ``output`` instead; when
    ``output`` is not a dict, the copy is ``{'raw': str(output)}``. The copy
    has ``_validation_status`` set to ``'invalid'`` and ``_validation_error``
    set to the exception text.
    """
    try:
        payload = model_class(**output).model_dump()
    except Exception as exc:
        # Best-effort path: keep what the agent produced and annotate it.
        if isinstance(output, dict):
            fallback = dict(output)
        else:
            fallback = {'raw': str(output)}
        fallback['_validation_status'] = 'invalid'
        fallback['_validation_error'] = str(exc)
        return fallback
    payload['_validation_status'] = 'valid'
    return payload