TradingAgents/tradingagents/agents/utils/agent_states.py

144 lines
5.7 KiB
Python

from typing import Annotated, Sequence
from datetime import date, timedelta, datetime
from typing_extensions import TypedDict, Optional
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
from typing import Dict, List
# Researcher team state
class PortfolioPosition(TypedDict):
ticker: str
shares: int
average_cost: float
current_value: float
unrealized_pnl: float
unrealized_pnl_pct: float
entry_date: str
class InvestDebateState(TypedDict):
bull_history: Annotated[
str, "Bullish Conversation history"
] # Bullish Conversation history
bear_history: Annotated[
str, "Bearish Conversation history"
] # Bullish Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
current_response: Annotated[str, "Latest response"] # Last response
judge_decision: Annotated[str, "Final judge decision"] # Last response
count: Annotated[int, "Length of the current conversation"] # Conversation length
# Enhanced Logic Fields
last_argument_invalid: Annotated[bool, "Was the last argument rejected?"]
rejection_reason: Annotated[str, "Reason for rejection"]
latest_speaker: Annotated[str, "Who spoke last (Bull/Bear)"]
confidence: Annotated[float, "Confidence in current position (0-1)"]
# Risk management team state
class RiskDebateState(TypedDict):
risky_history: Annotated[
str, "Risky Agent's Conversation history"
] # Conversation history
safe_history: Annotated[
str, "Safe Agent's Conversation history"
] # Conversation history
neutral_history: Annotated[
str, "Neutral Agent's Conversation history"
] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[
str, "Latest response by the risky analyst"
] # Last response
current_safe_response: Annotated[
str, "Latest response by the safe analyst"
] # Last response
current_neutral_response: Annotated[
str, "Latest response by the neutral analyst"
] # Last response
judge_decision: Annotated[str, "Judge's decision"]
count: Annotated[int, "Length of the current conversation"] # Conversation length
# Enhanced Logic Fields
invalid_reasoning_detected: Annotated[bool, "Was invalid reasoning detected?"]
error_message: Annotated[str, "Error message for invalid reasoning"]
def reduce_overwrite(left, right):
"""
Reducer that allows overwriting the value.
In case of concurrent identical updates (like parallel subgraphs returning inputs),
this resolves the conflict by taking the last value (which is identical).
"""
return right
# 1. Define a specific reducer for the Risk Debate Dictionary
def merge_risk_states(left: dict, right: dict) -> dict:
"""Safely merges updates from parallel risk analysts."""
if not left: return right
if not right: return left
return {**left, **right}
class AgentState(MessagesState):
company_of_interest: Annotated[str, reduce_overwrite] # "Company that we are interested in trading"
trade_date: Annotated[str, reduce_overwrite] # "What date we are trading at"
sender: Annotated[str, "Agent that sent this message"]
# research step
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[
str, "Report from the News Researcher of current world affairs"
]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
# regime data
# regime data
market_regime: Annotated[str, "Current Market Regime (e.g. VOLATILE, TRENDING_UP)"]
broad_market_regime: Annotated[str, "Broad Market Context (e.g. SPY Regime)"]
regime_metrics: Annotated[dict, "Metrics used to determine regime"]
volatility_score: Annotated[float, "Current Volatility Score"]
net_insider_flow: Annotated[float, "Net Insider Transaction Flow (Last 90 Days)"]
portfolio: Annotated[Dict[str, PortfolioPosition], "Current active holdings"]
cash_balance: Annotated[float, "Current cash balance"]
risk_multiplier: Annotated[float, "Calculated Risk Multiplier based on Relative Strength"]
# researcher team discussion step
investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not"
]
investment_plan: Annotated[str, "Plan generated by the Analyst"]
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
# risk management team discussion step
risk_debate_state: Annotated[
RiskDebateState, merge_risk_states
]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
# --- STRICT ANALYST STATES FOR SUBGRAPHS ---
# These ensure parallel analysts cannot touch global state (portfolio, risk, etc.)
class BaseAnalystState(MessagesState):
"""Base state for an isolated analyst subgraph.
Inherits 'messages' from MessagesState.
"""
company_of_interest: Annotated[str, reduce_overwrite]
trade_date: Annotated[str, reduce_overwrite]
sender: Annotated[str, "Agent name (internal to subgraph)"]
class SocialAnalystState(BaseAnalystState):
sentiment_report: Annotated[str, "Output report"]
class NewsAnalystState(BaseAnalystState):
news_report: Annotated[str, "Output report"]
# Additional news-specific fields if needed, but keeping it minimal
class FundamentalsAnalystState(BaseAnalystState):
fundamentals_report: Annotated[str, "Output report"]