Merge 5450fc9e81 into fa4d01c23a
This commit is contained in:
commit
9a381bbc25
|
|
@ -0,0 +1,49 @@
|
||||||
|
"""
|
||||||
|
Example: Using Pydantic validation in analyst agents.
|
||||||
|
|
||||||
|
This demonstrates how to add validation at agent boundaries
|
||||||
|
to catch errors early and provide clear feedback.
|
||||||
|
|
||||||
|
Issue #434: https://github.com/TauricResearch/TradingAgents/issues/434
|
||||||
|
"""
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.pydantic_validation import (
|
||||||
|
AnalystReport,
|
||||||
|
safe_validate_agent_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_market_analyst_with_validation(llm):
|
||||||
|
"""
|
||||||
|
Enhanced market analyst with Pydantic validation.
|
||||||
|
|
||||||
|
Wraps the standard market analyst to validate outputs
|
||||||
|
and provide clear error messages when validation fails.
|
||||||
|
"""
|
||||||
|
from tradingagents.agents.analysts.market_analyst import create_market_analyst
|
||||||
|
|
||||||
|
# Get the original analyst
|
||||||
|
original_analyst = create_market_analyst(llm)
|
||||||
|
|
||||||
|
def validated_market_analyst_node(state):
|
||||||
|
# Run the original analyst
|
||||||
|
result = original_analyst(state)
|
||||||
|
|
||||||
|
# Validate the output
|
||||||
|
if isinstance(result, dict):
|
||||||
|
validated = safe_validate_agent_output(result, AnalystReport)
|
||||||
|
|
||||||
|
if validated.get('_validation_status') == 'invalid':
|
||||||
|
# Log validation error but continue with original output
|
||||||
|
print(f"⚠️ Validation warning: {validated.get('_validation_error')}")
|
||||||
|
|
||||||
|
return validated
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return validated_market_analyst_node
|
||||||
|
|
||||||
|
|
||||||
|
# Usage example:
|
||||||
|
# from tradingagents.agents.utils.pydantic_validation import create_market_analyst_with_validation
|
||||||
|
# validated_analyst = create_market_analyst_with_validation(llm)
|
||||||
|
|
@ -29,6 +29,7 @@ dependencies = [
|
||||||
"stockstats>=0.6.5",
|
"stockstats>=0.6.5",
|
||||||
"tqdm>=4.67.1",
|
"tqdm>=4.67.1",
|
||||||
"typing-extensions>=4.14.0",
|
"typing-extensions>=4.14.0",
|
||||||
|
"pydantic>=2.0.0",
|
||||||
"yfinance>=0.2.63",
|
"yfinance>=0.2.63",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,3 +70,19 @@ class AgentState(MessagesState):
|
||||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
RiskDebateState, "Current state of the debate on evaluating risk"
|
||||||
]
|
]
|
||||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||||
|
|
||||||
|
|
||||||
|
# Import Pydantic validation helpers (Issue #434)
|
||||||
|
try:
|
||||||
|
from tradingagents.agents.utils.pydantic_validation import (
|
||||||
|
AnalystReport,
|
||||||
|
InvestDebateStateValidated,
|
||||||
|
RiskDebateStateValidated,
|
||||||
|
TradeDecision,
|
||||||
|
AgentInput,
|
||||||
|
validate_agent_output,
|
||||||
|
safe_validate_agent_output,
|
||||||
|
)
|
||||||
|
HAS_PYDANTIC_VALIDATION = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_PYDANTIC_VALIDATION = False
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,178 @@
|
||||||
|
"""
|
||||||
|
Pydantic validation models for TradingAgents.
|
||||||
|
|
||||||
|
Provides strict schema validation at agent boundaries to catch
|
||||||
|
validation errors early and provide clear error messages.
|
||||||
|
|
||||||
|
Issue #434: https://github.com/TauricResearch/TradingAgents/issues/434
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from datetime import date as DateType
|
||||||
|
|
||||||
|
|
||||||
|
class AnalystReport(BaseModel):
|
||||||
|
"""Validated analyst report output."""
|
||||||
|
|
||||||
|
report: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=10,
|
||||||
|
description="Detailed analyst report with market insights"
|
||||||
|
)
|
||||||
|
indicators_used: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="List of technical indicators used in the analysis"
|
||||||
|
)
|
||||||
|
has_trade_proposal: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether the report contains a FINAL TRANSACTION PROPOSAL"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator('report')
|
||||||
|
@classmethod
|
||||||
|
def validate_report_quality(cls, v: str) -> str:
|
||||||
|
if len(v.strip()) < 10:
|
||||||
|
raise ValueError("Report must be at least 10 characters")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class InvestDebateStateValidated(BaseModel):
|
||||||
|
"""Validated research debate state."""
|
||||||
|
|
||||||
|
bull_history: str = Field(default="", description="Bullish conversation history")
|
||||||
|
bear_history: str = Field(default="", description="Bearish conversation history")
|
||||||
|
history: str = Field(default="", description="Full conversation history")
|
||||||
|
current_response: str = Field(default="", description="Latest response")
|
||||||
|
judge_decision: str = Field(default="", description="Final judge decision")
|
||||||
|
count: int = Field(default=0, ge=0, description="Conversation length")
|
||||||
|
|
||||||
|
@field_validator('judge_decision')
|
||||||
|
@classmethod
|
||||||
|
def validate_judge_decision(cls, v: str) -> str:
|
||||||
|
if v and v.strip().upper() not in ['', 'BUY', 'SELL', 'HOLD']:
|
||||||
|
# Allow any text but warn about non-standard decisions
|
||||||
|
pass
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class RiskDebateStateValidated(BaseModel):
|
||||||
|
"""Validated risk management debate state."""
|
||||||
|
|
||||||
|
aggressive_history: str = Field(default="", description="Aggressive agent history")
|
||||||
|
conservative_history: str = Field(default="", description="Conservative agent history")
|
||||||
|
neutral_history: str = Field(default="", description="Neutral agent history")
|
||||||
|
history: str = Field(default="", description="Full conversation history")
|
||||||
|
latest_speaker: str = Field(default="", description="Last speaker")
|
||||||
|
current_aggressive_response: str = Field(default="")
|
||||||
|
current_conservative_response: str = Field(default="")
|
||||||
|
current_neutral_response: str = Field(default="")
|
||||||
|
judge_decision: str = Field(default="", description="Judge's decision")
|
||||||
|
count: int = Field(default=0, ge=0, description="Conversation length")
|
||||||
|
|
||||||
|
|
||||||
|
class TradeDecision(BaseModel):
|
||||||
|
"""Validated final trade decision."""
|
||||||
|
|
||||||
|
decision: str = Field(
|
||||||
|
...,
|
||||||
|
description="Trade decision: BUY, SELL, or HOLD"
|
||||||
|
)
|
||||||
|
confidence: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Confidence level (0-1)"
|
||||||
|
)
|
||||||
|
reasoning: str = Field(
|
||||||
|
default="",
|
||||||
|
description="Brief reasoning for the decision"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator('decision')
|
||||||
|
@classmethod
|
||||||
|
def validate_decision(cls, v: str) -> str:
|
||||||
|
v = v.strip().upper()
|
||||||
|
if v not in ['BUY', 'SELL', 'HOLD']:
|
||||||
|
raise ValueError(f"Decision must be BUY, SELL, or HOLD, got: {v}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInput(BaseModel):
|
||||||
|
"""Validated agent input state."""
|
||||||
|
|
||||||
|
company_of_interest: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=1,
|
||||||
|
max_length=10,
|
||||||
|
description="Stock ticker symbol"
|
||||||
|
)
|
||||||
|
trade_date: str = Field(
|
||||||
|
...,
|
||||||
|
description="Trading date in YYYY-MM-DD format"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator('company_of_interest')
|
||||||
|
@classmethod
|
||||||
|
def validate_ticker(cls, v: str) -> str:
|
||||||
|
v = v.strip().upper()
|
||||||
|
if not v.isalpha():
|
||||||
|
raise ValueError(f"Ticker must be alphabetic, got: {v}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator('trade_date')
|
||||||
|
@classmethod
|
||||||
|
def validate_date(cls, v: str) -> str:
|
||||||
|
try:
|
||||||
|
DateType.fromisoformat(v)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Invalid date format: {v}. Expected YYYY-MM-DD")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def validate_agent_output(
|
||||||
|
output: Dict[str, Any],
|
||||||
|
model_class: type[BaseModel]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Validate agent output against a Pydantic model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: Raw agent output dictionary
|
||||||
|
model_class: Pydantic model class to validate against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated dictionary
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If validation fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
validated = model_class(**output)
|
||||||
|
return validated.model_dump()
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Agent output validation failed for {model_class.__name__}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_validate_agent_output(
|
||||||
|
output: Dict[str, Any],
|
||||||
|
model_class: type[BaseModel]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Safely validate agent output with fallback.
|
||||||
|
|
||||||
|
If validation fails, returns the original output with an error field.
|
||||||
|
Does not raise exceptions.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
validated = model_class(**output)
|
||||||
|
result = validated.model_dump()
|
||||||
|
result['_validation_status'] = 'valid'
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
result = dict(output) if isinstance(output, dict) else {'raw': str(output)}
|
||||||
|
result['_validation_status'] = 'invalid'
|
||||||
|
result['_validation_error'] = str(e)
|
||||||
|
return result
|
||||||
Loading…
Reference in New Issue