feat: update AgentState and conditional logic for Polymarket 3-way debate
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
12051c1570
commit
39cce0fb9b
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""Tests for updated agent state definitions."""
|
||||||
|
|
||||||
|
|
||||||
|
def test_invest_debate_state_has_timing_fields():
|
||||||
|
from tradingagents.agents.utils.agent_states import InvestDebateState
|
||||||
|
keys = InvestDebateState.__annotations__
|
||||||
|
assert "yes_history" in keys
|
||||||
|
assert "no_history" in keys
|
||||||
|
assert "timing_history" in keys
|
||||||
|
assert "latest_speaker" in keys
|
||||||
|
assert "current_yes_response" in keys
|
||||||
|
assert "current_no_response" in keys
|
||||||
|
assert "current_timing_response" in keys
|
||||||
|
assert "bull_history" not in keys
|
||||||
|
assert "bear_history" not in keys
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_state_has_polymarket_fields():
|
||||||
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
|
keys = AgentState.__annotations__
|
||||||
|
assert "event_id" in keys
|
||||||
|
assert "event_question" in keys
|
||||||
|
assert "odds_report" in keys
|
||||||
|
assert "event_report" in keys
|
||||||
|
assert "trader_plan" in keys
|
||||||
|
assert "final_decision" in keys
|
||||||
|
assert "company_of_interest" not in keys
|
||||||
|
assert "market_report" not in keys
|
||||||
|
assert "fundamentals_report" not in keys
|
||||||
|
assert "trader_investment_plan" not in keys
|
||||||
|
assert "final_trade_decision" not in keys
|
||||||
|
|
||||||
|
|
||||||
|
def test_risk_debate_state_unchanged():
|
||||||
|
from tradingagents.agents.utils.agent_states import RiskDebateState
|
||||||
|
keys = RiskDebateState.__annotations__
|
||||||
|
assert "aggressive_history" in keys
|
||||||
|
assert "conservative_history" in keys
|
||||||
|
assert "neutral_history" in keys
|
||||||
|
assert "latest_speaker" in keys
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""Tests for conditional logic routing."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
def test_debate_routes_yes_to_no():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||||
|
state = {"investment_debate_state": {"count": 1, "latest_speaker": "YES Advocate"}}
|
||||||
|
assert cl.should_continue_debate(state) == "NO Advocate"
|
||||||
|
|
||||||
|
|
||||||
|
def test_debate_routes_no_to_timing():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||||
|
state = {"investment_debate_state": {"count": 2, "latest_speaker": "NO Advocate"}}
|
||||||
|
assert cl.should_continue_debate(state) == "Timing Advocate"
|
||||||
|
|
||||||
|
|
||||||
|
def test_debate_routes_timing_to_yes():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=1)
|
||||||
|
state = {"investment_debate_state": {"count": 3, "latest_speaker": "Timing Advocate"}}
|
||||||
|
assert cl.should_continue_debate(state) == "YES Advocate"
|
||||||
|
|
||||||
|
|
||||||
|
def test_debate_routes_to_manager_after_max_rounds():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||||
|
state = {"investment_debate_state": {"count": 3, "latest_speaker": "Timing Advocate"}}
|
||||||
|
assert cl.should_continue_debate(state) == "Research Manager"
|
||||||
|
|
||||||
|
|
||||||
|
def test_debate_initial_routes_to_yes():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||||
|
state = {"investment_debate_state": {"count": 0, "latest_speaker": ""}}
|
||||||
|
assert cl.should_continue_debate(state) == "YES Advocate"
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_continue_odds_routes_to_tools():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.tool_calls = [{"name": "get_market_data", "args": {}}]
|
||||||
|
state = {"messages": [msg]}
|
||||||
|
assert cl.should_continue_odds(state) == "tools_odds"
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_continue_odds_routes_to_clear():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.tool_calls = []
|
||||||
|
state = {"messages": [msg]}
|
||||||
|
assert cl.should_continue_odds(state) == "Msg Clear Odds"
|
||||||
|
|
||||||
|
|
||||||
|
def test_risk_routes_aggressive_to_conservative():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||||
|
state = {"risk_debate_state": {"count": 1, "latest_speaker": "Aggressive Analyst"}}
|
||||||
|
assert cl.should_continue_risk_analysis(state) == "Conservative Analyst"
|
||||||
|
|
||||||
|
|
||||||
|
def test_risk_routes_to_judge_after_max():
|
||||||
|
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||||
|
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||||
|
state = {"risk_debate_state": {"count": 3, "latest_speaker": "Neutral Analyst"}}
|
||||||
|
assert cl.should_continue_risk_analysis(state) == "Risk Judge"
|
||||||
|
|
@ -1,76 +1,47 @@
|
||||||
from typing import Annotated, Sequence
|
from langgraph.graph import MessagesState
|
||||||
from datetime import date, timedelta, datetime
|
from typing_extensions import TypedDict
|
||||||
from typing_extensions import TypedDict, Optional
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from tradingagents.agents import *
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
from langgraph.graph import END, StateGraph, START, MessagesState
|
|
||||||
|
|
||||||
|
|
||||||
# Researcher team state
|
|
||||||
class InvestDebateState(TypedDict):
|
class InvestDebateState(TypedDict):
|
||||||
bull_history: Annotated[
|
"""State for the YES/NO/Timing investment debate."""
|
||||||
str, "Bullish Conversation history"
|
yes_history: str
|
||||||
] # Bullish Conversation history
|
no_history: str
|
||||||
bear_history: Annotated[
|
timing_history: str
|
||||||
str, "Bearish Conversation history"
|
history: str
|
||||||
] # Bullish Conversation history
|
current_yes_response: str
|
||||||
history: Annotated[str, "Conversation history"] # Conversation history
|
current_no_response: str
|
||||||
current_response: Annotated[str, "Latest response"] # Last response
|
current_timing_response: str
|
||||||
judge_decision: Annotated[str, "Final judge decision"] # Last response
|
latest_speaker: str
|
||||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
judge_decision: str
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
# Risk management team state
|
|
||||||
class RiskDebateState(TypedDict):
|
class RiskDebateState(TypedDict):
|
||||||
aggressive_history: Annotated[
|
"""State for the Aggressive/Conservative/Neutral risk debate."""
|
||||||
str, "Aggressive Agent's Conversation history"
|
aggressive_history: str
|
||||||
] # Conversation history
|
conservative_history: str
|
||||||
conservative_history: Annotated[
|
neutral_history: str
|
||||||
str, "Conservative Agent's Conversation history"
|
history: str
|
||||||
] # Conversation history
|
latest_speaker: str
|
||||||
neutral_history: Annotated[
|
current_aggressive_response: str
|
||||||
str, "Neutral Agent's Conversation history"
|
current_conservative_response: str
|
||||||
] # Conversation history
|
current_neutral_response: str
|
||||||
history: Annotated[str, "Conversation history"] # Conversation history
|
judge_decision: str
|
||||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
count: int
|
||||||
current_aggressive_response: Annotated[
|
|
||||||
str, "Latest response by the aggressive analyst"
|
|
||||||
] # Last response
|
|
||||||
current_conservative_response: Annotated[
|
|
||||||
str, "Latest response by the conservative analyst"
|
|
||||||
] # Last response
|
|
||||||
current_neutral_response: Annotated[
|
|
||||||
str, "Latest response by the neutral analyst"
|
|
||||||
] # Last response
|
|
||||||
judge_decision: Annotated[str, "Judge's decision"]
|
|
||||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
|
||||||
|
|
||||||
|
|
||||||
class AgentState(MessagesState):
|
class AgentState(MessagesState):
|
||||||
company_of_interest: Annotated[str, "Company that we are interested in trading"]
|
"""Main agent state for Polymarket prediction analysis."""
|
||||||
trade_date: Annotated[str, "What date we are trading at"]
|
event_id: str
|
||||||
|
event_question: str
|
||||||
sender: Annotated[str, "Agent that sent this message"]
|
trade_date: str
|
||||||
|
sender: str
|
||||||
# research step
|
odds_report: str
|
||||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
sentiment_report: str
|
||||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
news_report: str
|
||||||
news_report: Annotated[
|
event_report: str
|
||||||
str, "Report from the News Researcher of current world affairs"
|
investment_debate_state: InvestDebateState
|
||||||
]
|
investment_plan: str
|
||||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
trader_plan: str
|
||||||
|
risk_debate_state: RiskDebateState
|
||||||
# researcher team discussion step
|
final_decision: str
|
||||||
investment_debate_state: Annotated[
|
|
||||||
InvestDebateState, "Current state of the debate on if to invest or not"
|
|
||||||
]
|
|
||||||
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
|
||||||
|
|
||||||
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
|
|
||||||
|
|
||||||
# risk management team discussion step
|
|
||||||
risk_debate_state: Annotated[
|
|
||||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
|
||||||
]
|
|
||||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
|
||||||
|
|
|
||||||
|
|
@ -1,67 +1,64 @@
|
||||||
# TradingAgents/graph/conditional_logic.py
|
"""Conditional routing logic for the trading agents graph."""
|
||||||
|
|
||||||
from tradingagents.agents.utils.agent_states import AgentState
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionalLogic:
|
class ConditionalLogic:
|
||||||
"""Handles conditional logic for determining graph flow."""
|
"""Handles conditional routing decisions in the graph."""
|
||||||
|
|
||||||
def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1):
|
def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1):
|
||||||
"""Initialize with configuration parameters."""
|
|
||||||
self.max_debate_rounds = max_debate_rounds
|
self.max_debate_rounds = max_debate_rounds
|
||||||
self.max_risk_discuss_rounds = max_risk_discuss_rounds
|
self.max_risk_discuss_rounds = max_risk_discuss_rounds
|
||||||
|
|
||||||
def should_continue_market(self, state: AgentState):
|
def should_continue_odds(self, state):
|
||||||
"""Determine if market analysis should continue."""
|
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
if last_message.tool_calls:
|
if last_message.tool_calls:
|
||||||
return "tools_market"
|
return "tools_odds"
|
||||||
return "Msg Clear Market"
|
return "Msg Clear Odds"
|
||||||
|
|
||||||
def should_continue_social(self, state: AgentState):
|
def should_continue_social(self, state):
|
||||||
"""Determine if social media analysis should continue."""
|
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
if last_message.tool_calls:
|
if last_message.tool_calls:
|
||||||
return "tools_social"
|
return "tools_social"
|
||||||
return "Msg Clear Social"
|
return "Msg Clear Social"
|
||||||
|
|
||||||
def should_continue_news(self, state: AgentState):
|
def should_continue_news(self, state):
|
||||||
"""Determine if news analysis should continue."""
|
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
if last_message.tool_calls:
|
if last_message.tool_calls:
|
||||||
return "tools_news"
|
return "tools_news"
|
||||||
return "Msg Clear News"
|
return "Msg Clear News"
|
||||||
|
|
||||||
def should_continue_fundamentals(self, state: AgentState):
|
def should_continue_event(self, state):
|
||||||
"""Determine if fundamentals analysis should continue."""
|
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
if last_message.tool_calls:
|
if last_message.tool_calls:
|
||||||
return "tools_fundamentals"
|
return "tools_event"
|
||||||
return "Msg Clear Fundamentals"
|
return "Msg Clear Event"
|
||||||
|
|
||||||
def should_continue_debate(self, state: AgentState) -> str:
|
def should_continue_debate(self, state):
|
||||||
"""Determine if debate should continue."""
|
"""Route 3-way YES/NO/Timing debate. Mirrors risk debate pattern."""
|
||||||
|
count = state["investment_debate_state"]["count"]
|
||||||
if (
|
if count >= 3 * self.max_debate_rounds:
|
||||||
state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds
|
|
||||||
): # 3 rounds of back-and-forth between 2 agents
|
|
||||||
return "Research Manager"
|
return "Research Manager"
|
||||||
if state["investment_debate_state"]["current_response"].startswith("Bull"):
|
latest = state["investment_debate_state"].get("latest_speaker", "")
|
||||||
return "Bear Researcher"
|
if latest.startswith("YES"):
|
||||||
return "Bull Researcher"
|
return "NO Advocate"
|
||||||
|
elif latest.startswith("NO"):
|
||||||
|
return "Timing Advocate"
|
||||||
|
else:
|
||||||
|
# Initial entry or after Timing -> start with YES
|
||||||
|
return "YES Advocate"
|
||||||
|
|
||||||
def should_continue_risk_analysis(self, state: AgentState) -> str:
|
def should_continue_risk_analysis(self, state):
|
||||||
"""Determine if risk analysis should continue."""
|
"""Route 3-way risk debate. Unchanged from original."""
|
||||||
if (
|
count = state["risk_debate_state"]["count"]
|
||||||
state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds
|
if count >= 3 * self.max_risk_discuss_rounds:
|
||||||
): # 3 rounds of back-and-forth between 3 agents
|
|
||||||
return "Risk Judge"
|
return "Risk Judge"
|
||||||
if state["risk_debate_state"]["latest_speaker"].startswith("Aggressive"):
|
latest = state["risk_debate_state"].get("latest_speaker", "")
|
||||||
|
if latest.startswith("Aggressive"):
|
||||||
return "Conservative Analyst"
|
return "Conservative Analyst"
|
||||||
if state["risk_debate_state"]["latest_speaker"].startswith("Conservative"):
|
elif latest.startswith("Conservative"):
|
||||||
return "Neutral Analyst"
|
return "Neutral Analyst"
|
||||||
return "Aggressive Analyst"
|
else:
|
||||||
|
return "Aggressive Analyst"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue