# TradingAgents/tradingagents/graph/conditional_logic.py

# 92 lines
# 3.6 KiB
# Python

# TradingAgents/graph/conditional_logic.py
from tradingagents.agents.utils.agent_states import AgentState
class ConditionalLogic:
    """Decide which graph node runs next, based on the current agent state.

    Every public ``should_continue_*`` method takes the state and returns the
    name of the next node as a string.
    """

    # Text that, when present in a report, short-circuits the debates.
    _CRITICAL_ABORT_MARKER = "[CRITICAL ABORT]"

    # Reports scanned for the abort marker, in evaluation order.
    _ABORT_REPORT_FIELDS = ("market_report", "fundamentals_report")

    def __init__(self, max_debate_rounds: int = 1, max_risk_discuss_rounds: int = 1):
        """Initialize with configuration parameters.

        Args:
            max_debate_rounds: Number of rounds allowed in the investment
                debate before control passes to the "Research Manager".
            max_risk_discuss_rounds: Number of rounds allowed in the risk
                discussion before control passes to the "Portfolio Manager".
        """
        self.max_debate_rounds = max_debate_rounds
        self.max_risk_discuss_rounds = max_risk_discuss_rounds

    def _check_critical_abort(self, state: AgentState, report_field: str) -> bool:
        """Check if a report contains the [CRITICAL ABORT] trigger.

        Args:
            state: The current agent state.
            report_field: The field name to check
                (e.g., 'market_report', 'fundamentals_report').

        Returns:
            True if [CRITICAL ABORT] is found in the report, False when the
            field is missing, empty, or does not contain the marker.
        """
        report = state.get(report_field, "")
        if not report:
            return False
        return self._CRITICAL_ABORT_MARKER in report

    def _abort_requested(self, state: AgentState) -> bool:
        """Return True if any monitored report carries the abort marker."""
        return any(
            self._check_critical_abort(state, field)
            for field in self._ABORT_REPORT_FIELDS
        )

    @staticmethod
    def _route_after_analyst(state: AgentState, tools_node: str, clear_node: str) -> str:
        """Pick the tool node or the message-clear node for an analyst.

        Args:
            state: The current agent state; must contain a "messages" entry.
            tools_node: Node name returned when the last message has
                pending tool calls.
            clear_node: Node name returned otherwise.

        Returns:
            ``tools_node`` if the last message has a non-empty ``tool_calls``
            attribute, else ``clear_node``.
        """
        messages = state["messages"]
        last_message = messages[-1] if messages else None
        # getattr keeps routing alive for an empty history or a message type
        # that has no ``tool_calls`` attribute: both simply mean "no tools".
        if getattr(last_message, "tool_calls", None):
            return tools_node
        return clear_node

    def should_continue_market(self, state: AgentState) -> str:
        """Determine if market analysis should continue."""
        return self._route_after_analyst(state, "tools_market", "Msg Clear Market")

    def should_continue_social(self, state: AgentState) -> str:
        """Determine if social media analysis should continue."""
        return self._route_after_analyst(state, "tools_social", "Msg Clear Social")

    def should_continue_news(self, state: AgentState) -> str:
        """Determine if news analysis should continue."""
        return self._route_after_analyst(state, "tools_news", "Msg Clear News")

    def should_continue_fundamentals(self, state: AgentState) -> str:
        """Determine if fundamentals analysis should continue."""
        return self._route_after_analyst(
            state, "tools_fundamentals", "Msg Clear Fundamentals"
        )

    def should_continue_debate(self, state: AgentState) -> str:
        """Determine if debate should continue.

        Returns:
            "Portfolio Manager" on a critical abort, "Research Manager" once
            the turn budget is used up, otherwise the researcher who did not
            produce the current response.
        """
        # A critical abort in either monitored report skips the debate.
        if self._abort_requested(state):
            return "Portfolio Manager"

        debate = state["investment_debate_state"]
        # Two researchers alternate, so the budget is 2 turns per round.
        # NOTE(review): assumes "count" is bumped once per turn - confirm
        # against the researcher nodes.
        if debate["count"] >= 2 * self.max_debate_rounds:
            return "Research Manager"
        if debate["current_response"].startswith("Bull"):
            return "Bear Researcher"
        return "Bull Researcher"

    def should_continue_risk_analysis(self, state: AgentState) -> str:
        """Determine if risk analysis should continue.

        Returns:
            "Portfolio Manager" on a critical abort or once the turn budget
            is used up, otherwise the next analyst in the rotation
            Aggressive -> Conservative -> Neutral -> Aggressive.
        """
        # A critical abort in either monitored report ends the discussion.
        if self._abort_requested(state):
            return "Portfolio Manager"

        risk = state["risk_debate_state"]
        # Three analysts rotate, so the budget is 3 turns per round.
        if risk["count"] >= 3 * self.max_risk_discuss_rounds:
            return "Portfolio Manager"

        latest_speaker = risk["latest_speaker"]
        if latest_speaker.startswith("Aggressive"):
            return "Conservative Analyst"
        if latest_speaker.startswith("Conservative"):
            return "Neutral Analyst"
        return "Aggressive Analyst"