diff --git a/tests/test_agent_states.py b/tests/test_agent_states.py new file mode 100644 index 00000000..e31d4182 --- /dev/null +++ b/tests/test_agent_states.py @@ -0,0 +1,40 @@ +"""Tests for updated agent state definitions.""" + + +def test_invest_debate_state_has_timing_fields(): + from tradingagents.agents.utils.agent_states import InvestDebateState + keys = InvestDebateState.__annotations__ + assert "yes_history" in keys + assert "no_history" in keys + assert "timing_history" in keys + assert "latest_speaker" in keys + assert "current_yes_response" in keys + assert "current_no_response" in keys + assert "current_timing_response" in keys + assert "bull_history" not in keys + assert "bear_history" not in keys + + +def test_agent_state_has_polymarket_fields(): + from tradingagents.agents.utils.agent_states import AgentState + keys = AgentState.__annotations__ + assert "event_id" in keys + assert "event_question" in keys + assert "odds_report" in keys + assert "event_report" in keys + assert "trader_plan" in keys + assert "final_decision" in keys + assert "company_of_interest" not in keys + assert "market_report" not in keys + assert "fundamentals_report" not in keys + assert "trader_investment_plan" not in keys + assert "final_trade_decision" not in keys + + +def test_risk_debate_state_unchanged(): + from tradingagents.agents.utils.agent_states import RiskDebateState + keys = RiskDebateState.__annotations__ + assert "aggressive_history" in keys + assert "conservative_history" in keys + assert "neutral_history" in keys + assert "latest_speaker" in keys diff --git a/tests/test_conditional_logic.py b/tests/test_conditional_logic.py new file mode 100644 index 00000000..71190545 --- /dev/null +++ b/tests/test_conditional_logic.py @@ -0,0 +1,69 @@ +"""Tests for conditional logic routing.""" +from unittest.mock import MagicMock + + +def test_debate_routes_yes_to_no(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1) + state = {"investment_debate_state": {"count": 1, "latest_speaker": "YES Advocate"}} + assert cl.should_continue_debate(state) == "NO Advocate" + + +def test_debate_routes_no_to_timing(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1) + state = {"investment_debate_state": {"count": 2, "latest_speaker": "NO Advocate"}} + assert cl.should_continue_debate(state) == "Timing Advocate" + + +def test_debate_routes_timing_to_yes(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=1) + state = {"investment_debate_state": {"count": 3, "latest_speaker": "Timing Advocate"}} + assert cl.should_continue_debate(state) == "YES Advocate" + + +def test_debate_routes_to_manager_after_max_rounds(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1) + state = {"investment_debate_state": {"count": 3, "latest_speaker": "Timing Advocate"}} + assert cl.should_continue_debate(state) == "Research Manager" + + +def test_debate_initial_routes_to_yes(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1) + state = {"investment_debate_state": {"count": 0, "latest_speaker": ""}} + assert cl.should_continue_debate(state) == "YES Advocate" + + +def test_should_continue_odds_routes_to_tools(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1) + msg = MagicMock() + msg.tool_calls = [{"name": "get_market_data", "args": {}}] + state = {"messages": [msg]} + assert cl.should_continue_odds(state) == "tools_odds" + + +def test_should_continue_odds_routes_to_clear(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1) + msg = MagicMock() + msg.tool_calls = [] + state = {"messages": [msg]} + assert cl.should_continue_odds(state) == "Msg Clear Odds" + + +def test_risk_routes_aggressive_to_conservative(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1) + state = {"risk_debate_state": {"count": 1, "latest_speaker": "Aggressive Analyst"}} + assert cl.should_continue_risk_analysis(state) == "Conservative Analyst" + + +def test_risk_routes_to_judge_after_max(): + from tradingagents.graph.conditional_logic import ConditionalLogic + cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1) + state = {"risk_debate_state": {"count": 3, "latest_speaker": "Neutral Analyst"}} + assert cl.should_continue_risk_analysis(state) == "Risk Judge" diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 813b00ee..49f275f4 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -1,76 +1,47 @@ -from typing import Annotated, Sequence -from datetime import date, timedelta, datetime -from typing_extensions import TypedDict, Optional -from langchain_openai import ChatOpenAI -from tradingagents.agents import * -from langgraph.prebuilt import ToolNode -from langgraph.graph import END, StateGraph, START, MessagesState +from langgraph.graph import MessagesState +from typing_extensions import TypedDict -# Researcher team state class InvestDebateState(TypedDict): - bull_history: Annotated[ - str, "Bullish Conversation history" - ] # Bullish Conversation history - bear_history: Annotated[ - str, "Bearish Conversation history" - ] # Bullish Conversation history - history: Annotated[str, "Conversation history"] # Conversation history - current_response: Annotated[str, "Latest response"] # Last response - judge_decision: Annotated[str, "Final judge decision"] # Last response - count: Annotated[int, "Length of the current conversation"] # Conversation length + """State for the YES/NO/Timing investment debate.""" + yes_history: str + no_history: str + timing_history: str + history: str + current_yes_response: str + current_no_response: str + current_timing_response: str + latest_speaker: str + judge_decision: str + count: int -# Risk management team state class RiskDebateState(TypedDict): - aggressive_history: Annotated[ - str, "Aggressive Agent's Conversation history" - ] # Conversation history - conservative_history: Annotated[ - str, "Conservative Agent's Conversation history" - ] # Conversation history - neutral_history: Annotated[ - str, "Neutral Agent's Conversation history" - ] # Conversation history - history: Annotated[str, "Conversation history"] # Conversation history - latest_speaker: Annotated[str, "Analyst that spoke last"] - current_aggressive_response: Annotated[ - str, "Latest response by the aggressive analyst" - ] # Last response - current_conservative_response: Annotated[ - str, "Latest response by the conservative analyst" - ] # Last response - current_neutral_response: Annotated[ - str, "Latest response by the neutral analyst" - ] # Last response - judge_decision: Annotated[str, "Judge's decision"] - count: Annotated[int, "Length of the current conversation"] # Conversation length + """State for the Aggressive/Conservative/Neutral risk debate.""" + aggressive_history: str + conservative_history: str + neutral_history: str + history: str + latest_speaker: str + current_aggressive_response: str + current_conservative_response: str + current_neutral_response: str + judge_decision: str + count: int class AgentState(MessagesState): - company_of_interest: Annotated[str, "Company that we are interested in trading"] - trade_date: Annotated[str, "What date we are trading at"] - - sender: Annotated[str, "Agent that sent this message"] - - # research step - market_report: Annotated[str, "Report from the Market Analyst"] - sentiment_report: Annotated[str, "Report from the Social Media Analyst"] - news_report: Annotated[ - str, "Report from the News Researcher of current world affairs" - ] - fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] - - # researcher team discussion step - investment_debate_state: Annotated[ - InvestDebateState, "Current state of the debate on if to invest or not" - ] - investment_plan: Annotated[str, "Plan generated by the Analyst"] - - trader_investment_plan: Annotated[str, "Plan generated by the Trader"] - - # risk management team discussion step - risk_debate_state: Annotated[ - RiskDebateState, "Current state of the debate on evaluating risk" - ] - final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] + """Main agent state for Polymarket prediction analysis.""" + event_id: str + event_question: str + trade_date: str + sender: str + odds_report: str + sentiment_report: str + news_report: str + event_report: str + investment_debate_state: InvestDebateState + investment_plan: str + trader_plan: str + risk_debate_state: RiskDebateState + final_decision: str diff --git a/tradingagents/graph/conditional_logic.py b/tradingagents/graph/conditional_logic.py index 7b1b1f90..92a9cc43 100644 --- a/tradingagents/graph/conditional_logic.py +++ b/tradingagents/graph/conditional_logic.py @@ -1,67 +1,64 @@ -# TradingAgents/graph/conditional_logic.py - -from tradingagents.agents.utils.agent_states import AgentState +"""Conditional routing logic for the trading agents graph.""" class ConditionalLogic: - """Handles conditional logic for determining graph flow.""" + """Handles conditional routing decisions in the graph.""" def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1): - """Initialize with configuration parameters.""" self.max_debate_rounds = max_debate_rounds self.max_risk_discuss_rounds = max_risk_discuss_rounds - def should_continue_market(self, state: AgentState): - """Determine if market analysis should continue.""" + def should_continue_odds(self, state): messages = state["messages"] last_message = messages[-1] if last_message.tool_calls: - return "tools_market" - return "Msg Clear Market" + return "tools_odds" + return "Msg Clear Odds" - def should_continue_social(self, state: AgentState): - """Determine if social media analysis should continue.""" + def should_continue_social(self, state): messages = state["messages"] last_message = messages[-1] if last_message.tool_calls: return "tools_social" return "Msg Clear Social" - def should_continue_news(self, state: AgentState): - """Determine if news analysis should continue.""" + def should_continue_news(self, state): messages = state["messages"] last_message = messages[-1] if last_message.tool_calls: return "tools_news" return "Msg Clear News" - def should_continue_fundamentals(self, state: AgentState): - """Determine if fundamentals analysis should continue.""" + def should_continue_event(self, state): messages = state["messages"] last_message = messages[-1] if last_message.tool_calls: - return "tools_fundamentals" - return "Msg Clear Fundamentals" + return "tools_event" + return "Msg Clear Event" - def should_continue_debate(self, state: AgentState) -> str: - """Determine if debate should continue.""" - - if ( - state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds - ): # 3 rounds of back-and-forth between 2 agents + def should_continue_debate(self, state): + """Route 3-way YES/NO/Timing debate. Mirrors risk debate pattern.""" + count = state["investment_debate_state"]["count"] + if count >= 3 * self.max_debate_rounds: return "Research Manager" - if state["investment_debate_state"]["current_response"].startswith("Bull"): - return "Bear Researcher" - return "Bull Researcher" + latest = state["investment_debate_state"].get("latest_speaker", "") + if latest.startswith("YES"): + return "NO Advocate" + elif latest.startswith("NO"): + return "Timing Advocate" + else: + # Initial entry or after Timing -> start with YES + return "YES Advocate" - def should_continue_risk_analysis(self, state: AgentState) -> str: - """Determine if risk analysis should continue.""" - if ( - state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds - ): # 3 rounds of back-and-forth between 3 agents + def should_continue_risk_analysis(self, state): + """Route 3-way risk debate. Unchanged from original.""" + count = state["risk_debate_state"]["count"] + if count >= 3 * self.max_risk_discuss_rounds: return "Risk Judge" - if state["risk_debate_state"]["latest_speaker"].startswith("Aggressive"): + latest = state["risk_debate_state"].get("latest_speaker", "") + if latest.startswith("Aggressive"): return "Conservative Analyst" - if state["risk_debate_state"]["latest_speaker"].startswith("Conservative"): + elif latest.startswith("Conservative"): return "Neutral Analyst" - return "Aggressive Analyst" + else: + return "Aggressive Analyst"