TradingAgents/tradingagents/models.py

470 lines
16 KiB
Python

"""Pydantic structured output models for the equity ranking engine.
All agents return one of these models. Deterministic scoring functions
compute master_score, confidence penalties, position roles, and hard vetoes
using only structured outputs — no prose drives downstream decisions.
"""
from __future__ import annotations
import json
import logging
from typing import List, Literal, Optional, Tuple
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
RiskLevel = Literal["low", "medium", "high"]
PositionRole = Literal[
"Core Position",
"Strong Position",
"Tactical / Satellite",
"Watchlist",
"Avoid",
]
Archetype = Literal[
"Infrastructure Builder",
"Bottleneck Supplier",
"Platform Company",
"Commodity Leverage",
"Secular Growth Innovator",
"Turnaround",
"Defensive Compounder",
]
# ---------------------------------------------------------------------------
# Data quality flag
# ---------------------------------------------------------------------------
class DataFlag(BaseModel):
"""A data quality issue discovered during analysis."""
field: str
severity: Literal["minor", "moderate", "severe"]
message: str
# ---------------------------------------------------------------------------
# Base output (inherited by most agents)
# ---------------------------------------------------------------------------
class AgentBaseOutput(BaseModel):
"""Common structured fields every analyst agent must return."""
agent_name: str = ""
score_0_to_10: float = Field(default=5.0, ge=0, le=10)
confidence_0_to_1: float = Field(default=0.5, ge=0, le=1)
key_positives: List[str] = Field(default_factory=list)
key_negatives: List[str] = Field(default_factory=list)
key_risks: List[str] = Field(default_factory=list)
data_quality_flags: List[DataFlag] = Field(default_factory=list)
veto: bool = False
veto_reason: Optional[str] = None
summary_1_sentence: str = ""
# ---------------------------------------------------------------------------
# Tier 1 outputs
# ---------------------------------------------------------------------------
class ValidationOutput(BaseModel):
"""Ticker validation and identity check."""
agent_name: str = "Validation"
ticker_valid: bool = True
ticker_resolved: str = ""
company_name: str = ""
company_name_match: bool = True
exchange: Optional[str] = None
sector: Optional[str] = None
industry: Optional[str] = None
is_active: bool = True
veto: bool = False
veto_reason: Optional[str] = None
data_quality_flags: List[DataFlag] = Field(default_factory=list)
class CompanyCard(BaseModel):
"""Cached company identity card."""
company_name: str = ""
ticker: str = ""
sector: str = "Unknown"
industry: str = "Unknown"
description: str = ""
market_cap: Optional[float] = None
market_cap_formatted: Optional[str] = None
market_cap_category: str = "unknown"
current_price: Optional[float] = None
revenue: Optional[float] = None
profit_margins: Optional[float] = None
employees: Optional[int] = None
competitors: List[str] = Field(default_factory=list)
class MacroRegimeOutput(AgentBaseOutput):
"""Macro regime assessment."""
agent_name: str = "Macro Regime"
vix_level: Optional[float] = None
vix_regime: str = "unknown"
ten_year_yield: Optional[float] = None
dollar_strength: str = "unknown"
credit_spread_direction: str = "unknown"
spy_1m_return: Optional[float] = None
regime_label: str = "unknown"
macro_alignment_0_to_10: float = Field(default=5.0, ge=0, le=10)
# Regime awareness
risk_appetite: str = "neutral" # risk-on / risk-off / transitional
liquidity_regime: str = "neutral" # expansion / contraction / neutral
regime_score_adjustment: float = Field(
default=0.0, ge=-10, le=10,
description="Adjustment applied to the 0-100 master score. "
"+10 = strong macro tailwind, -10 = severe macro headwind.",
)
class LiquidityOutput(AgentBaseOutput):
"""Liquidity and market conditions."""
agent_name: str = "Liquidity"
fed_stance: str = "unknown"
market_breadth: str = "unknown"
volume_profile: str = "unknown"
spy_trend: str = "unknown"
# ---------------------------------------------------------------------------
# Tier 2 outputs
# ---------------------------------------------------------------------------
class BusinessQualityOutput(AgentBaseOutput):
"""Business quality and competitive position."""
agent_name: str = "Business Quality"
revenue_growth: Optional[float] = None
profit_margins: Optional[float] = None
operating_margins: Optional[float] = None
return_on_equity: Optional[float] = None
return_on_assets: Optional[float] = None
debt_to_equity: Optional[float] = None
free_cashflow: Optional[float] = None
competitive_moat: str = "unknown"
management_quality: str = "unknown"
class InstitutionalFlowOutput(AgentBaseOutput):
"""Institutional ownership and flow signals."""
agent_name: str = "Institutional Flow"
institutional_ownership_pct: Optional[float] = None
insider_ownership_pct: Optional[float] = None
volume_ratio: Optional[float] = None
short_interest_pct: Optional[float] = None
short_ratio: Optional[float] = None
float_turnover_pct: Optional[float] = None
accumulation_signal: str = "unknown"
# Smart-money tracking
top_holders_change: str = "unknown" # increasing / decreasing / stable
fund_accumulation_pattern: str = "unknown" # accumulating / distributing / holding
short_interest_trend: str = "unknown" # rising / falling / stable
insider_transaction_signal: str = "unknown" # buying / selling / none
smart_money_signal: str = "unknown" # bullish / bearish / neutral
class ValuationOutput(AgentBaseOutput):
"""Valuation metrics and verdict."""
agent_name: str = "Valuation"
trailing_pe: Optional[float] = None
forward_pe: Optional[float] = None
peg_ratio: Optional[float] = None
price_to_book: Optional[float] = None
ev_to_ebitda: Optional[float] = None
price_to_sales: Optional[float] = None
vs_52w_range_pct: Optional[float] = None
valuation_verdict: str = "unknown"
class EntryTimingOutput(AgentBaseOutput):
"""Technical entry timing signals."""
agent_name: str = "Entry Timing"
current_price: Optional[float] = None
fifty_day_avg: Optional[float] = None
two_hundred_day_avg: Optional[float] = None
fifty_day_vs_200_day: str = "unknown"
vs_52w_range_pct: Optional[float] = None
timing_verdict: str = "unknown"
class EarningsRevisionOutput(AgentBaseOutput):
"""Earnings revision and analyst consensus."""
agent_name: str = "Earnings Revisions"
trailing_eps: Optional[float] = None
forward_eps: Optional[float] = None
eps_revision_direction: str = "unknown"
revenue_revision_direction: str = "unknown"
analyst_consensus: str = "unknown"
price_target_upside_pct: Optional[float] = None
num_analysts: Optional[int] = None
class SectorRotationOutput(AgentBaseOutput):
"""Sector rotation and relative strength."""
agent_name: str = "Sector Rotation"
sector: str = "Unknown"
sector_etf: Optional[str] = None
sector_vs_spy_1m: Optional[float] = None
sector_vs_spy_3m: Optional[float] = None
sector_rank: Optional[int] = None
total_sectors: int = 11
rotation_direction: str = "unknown"
class BacklogOrderMomentumOutput(AgentBaseOutput):
"""Backlog and order momentum (where applicable)."""
agent_name: str = "Backlog / Order Momentum"
has_backlog_data: bool = False
backlog_trend: str = "unknown"
order_momentum: str = "unknown"
class NarrativeCrowdingOutput(AgentBaseOutput):
"""Narrative crowding and contrarian signals."""
agent_name: str = "Narrative Crowding"
narrative_saturation: str = "unknown"
contrarian_opportunity: bool = False
media_sentiment: str = "unknown"
short_squeeze_potential: bool = False
class ArchetypeOutput(BaseModel):
"""Company archetype classification."""
archetype: str = "Secular Growth Innovator"
archetype_confidence: float = Field(default=0.5, ge=0, le=1)
reasoning: str = ""
# ---------------------------------------------------------------------------
# Tier 3 outputs
# ---------------------------------------------------------------------------
class BullCaseOutput(BaseModel):
"""Structured bull case thesis."""
thesis: str = ""
catalysts: List[str] = Field(default_factory=list)
upside_target: Optional[float] = None
upside_pct: Optional[float] = None
key_assumptions: List[str] = Field(default_factory=list)
thesis_invalidation_triggers: List[str] = Field(default_factory=list)
confidence_0_to_1: float = Field(default=0.5, ge=0, le=1)
class BearCaseOutput(BaseModel):
"""Structured bear case thesis."""
thesis: str = ""
risks: List[str] = Field(default_factory=list)
downside_target: Optional[float] = None
downside_pct: Optional[float] = None
key_assumptions: List[str] = Field(default_factory=list)
thesis_invalidation_triggers: List[str] = Field(default_factory=list)
confidence_0_to_1: float = Field(default=0.5, ge=0, le=1)
class DebateRefereeOutput(BaseModel):
"""Debate referee decision."""
winner: str = "bull"
reasoning: str = ""
bull_strength_0_to_10: float = Field(default=5.0, ge=0, le=10)
bear_strength_0_to_10: float = Field(default=5.0, ge=0, le=10)
key_unresolved_questions: List[str] = Field(default_factory=list)
net_conviction_adjustment: float = Field(default=0.0, ge=-2, le=2)
class RiskInvalidationOutput(AgentBaseOutput):
"""Risk assessment and invalidation triggers."""
agent_name: str = "Risk / Invalidation"
overall_risk_level: str = "medium"
max_position_size_pct: float = Field(default=5.0, ge=0, le=100)
stop_loss_pct: Optional[float] = None
invalidation_triggers: List[str] = Field(default_factory=list)
class FinalDecisionOutput(BaseModel):
"""Final synthesized decision with narrative."""
ticker: str = ""
company_name: str = ""
master_score: float = 0.0
adjusted_score: float = 0.0
confidence: float = 0.0
position_role: str = "Avoid"
action: str = "AVOID"
risk_level: str = "medium"
thesis_summary: str = ""
key_catalysts: List[str] = Field(default_factory=list)
key_risks: List[str] = Field(default_factory=list)
invalidation_triggers: List[str] = Field(default_factory=list)
position_sizing_pct: float = 0.0
narrative: str = ""
# ---------------------------------------------------------------------------
# Theme Substitution & Position Replacement outputs
# ---------------------------------------------------------------------------
class ThemeStock(BaseModel):
"""A stock ranked within a theme."""
ticker: str
company_name: str = ""
master_score_estimate: float = Field(default=5.0, ge=0, le=10)
key_advantage: str = ""
key_weakness: str = ""
class ThemeSubstitutionOutput(BaseModel):
"""Identifies whether a stock is the best expression of its theme."""
theme_name: str = ""
theme_description: str = ""
theme_stocks_ranked: List[ThemeStock] = Field(default_factory=list)
best_expression_of_theme: bool = True
best_expression_ticker: str = ""
stronger_alternatives: List[str] = Field(default_factory=list)
relative_score_gap: float = 0.0
portfolio_overlap_warning: str = ""
reasoning: str = ""
class PositionReplacementOutput(BaseModel):
"""Identifies when a new stock is a better use of capital."""
replace_candidate: str = ""
replace_with: str = ""
score_difference: float = 0.0
theme_overlap: str = ""
replacement_reason: str = ""
conviction_level: str = "low" # low / medium / high
stronger_on: List[str] = Field(
default_factory=list,
description="Dimensions where candidate beats replacement target",
)
weaker_on: List[str] = Field(default_factory=list)
should_replace: bool = False
# ---------------------------------------------------------------------------
# Deterministic scoring functions
# ---------------------------------------------------------------------------
def compute_master_score(
business_quality: float,
macro_alignment: float,
institutional_flow: float,
valuation: float,
entry_timing: float,
earnings_revisions: float,
backlog: float,
crowding: float,
regime_adjustment: float = 0.0,
) -> float:
"""Compute weighted master score (0-100).
Weights:
25% business_quality
20% macro_alignment
15% institutional_flow
10% valuation
10% entry_timing
10% earnings_revisions
5% backlog
5% crowding
regime_adjustment: -2 to +2, applied as direct offset to the 0-100 score.
"""
weighted = (
0.25 * business_quality
+ 0.20 * macro_alignment
+ 0.15 * institutional_flow
+ 0.10 * valuation
+ 0.10 * entry_timing
+ 0.10 * earnings_revisions
+ 0.05 * backlog
+ 0.05 * crowding
)
raw = weighted * 10
adjusted = max(0.0, min(100.0, raw + regime_adjustment))
return round(adjusted, 2)
def assign_position_role(score: float) -> str:
"""Map master score to a position role."""
if score > 80:
return "Core Position"
if score > 70:
return "Strong Position"
if score > 60:
return "Tactical / Satellite"
if score > 50:
return "Watchlist"
return "Avoid"
def apply_confidence_penalty(
base_score: float,
flags: List[DataFlag],
hard_veto: bool,
) -> float:
"""Reduce score based on data quality flags.
Penalties: minor=-0.5, moderate=-1.0, severe=-2.0
"""
if hard_veto:
return 0.0
penalty = 0.0
for flag in flags:
if flag.severity == "minor":
penalty += 0.5
elif flag.severity == "moderate":
penalty += 1.0
elif flag.severity == "severe":
penalty += 2.0
return max(0.0, base_score - penalty)
def should_hard_veto(
validation: Optional[ValidationOutput],
risk: Optional[RiskInvalidationOutput] = None,
) -> Tuple[bool, Optional[str]]:
"""Check if analysis should be hard-vetoed."""
if validation is not None:
if validation.veto:
return True, validation.veto_reason or "Validation veto"
if not validation.company_name_match:
return True, "Company name mismatch"
if not validation.ticker_valid:
return True, "Invalid ticker"
if risk is not None and risk.veto:
return True, risk.veto_reason or "Risk veto"
return False, None
# ---------------------------------------------------------------------------
# LLM structured output helper
# ---------------------------------------------------------------------------
def invoke_structured(llm, model_cls, prompt: str):
"""Call LLM with structured output, with JSON fallback."""
try:
structured = llm.with_structured_output(model_cls)
return structured.invoke(prompt)
except Exception as e:
logger.warning("Structured output failed for %s: %s — using JSON fallback", model_cls.__name__, e)
schema_str = json.dumps(model_cls.model_json_schema(), indent=2)
json_prompt = (
f"{prompt}\n\nReturn ONLY valid JSON matching this schema:\n{schema_str}"
)
response = llm.invoke(json_prompt)
content = response.content.strip()
if "```json" in content:
content = content.split("```json")[1].split("```")[0].strip()
elif "```" in content:
content = content.split("```")[1].split("```")[0].strip()
return model_cls.model_validate_json(content)