feat: update AgentState and conditional logic for Polymarket 3-way debate
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
12051c1570
commit
39cce0fb9b
|
|
@ -0,0 +1,40 @@
|
|||
"""Tests for updated agent state definitions."""
|
||||
|
||||
|
||||
def test_invest_debate_state_has_timing_fields():
|
||||
from tradingagents.agents.utils.agent_states import InvestDebateState
|
||||
keys = InvestDebateState.__annotations__
|
||||
assert "yes_history" in keys
|
||||
assert "no_history" in keys
|
||||
assert "timing_history" in keys
|
||||
assert "latest_speaker" in keys
|
||||
assert "current_yes_response" in keys
|
||||
assert "current_no_response" in keys
|
||||
assert "current_timing_response" in keys
|
||||
assert "bull_history" not in keys
|
||||
assert "bear_history" not in keys
|
||||
|
||||
|
||||
def test_agent_state_has_polymarket_fields():
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
keys = AgentState.__annotations__
|
||||
assert "event_id" in keys
|
||||
assert "event_question" in keys
|
||||
assert "odds_report" in keys
|
||||
assert "event_report" in keys
|
||||
assert "trader_plan" in keys
|
||||
assert "final_decision" in keys
|
||||
assert "company_of_interest" not in keys
|
||||
assert "market_report" not in keys
|
||||
assert "fundamentals_report" not in keys
|
||||
assert "trader_investment_plan" not in keys
|
||||
assert "final_trade_decision" not in keys
|
||||
|
||||
|
||||
def test_risk_debate_state_unchanged():
|
||||
from tradingagents.agents.utils.agent_states import RiskDebateState
|
||||
keys = RiskDebateState.__annotations__
|
||||
assert "aggressive_history" in keys
|
||||
assert "conservative_history" in keys
|
||||
assert "neutral_history" in keys
|
||||
assert "latest_speaker" in keys
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
"""Tests for conditional logic routing."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def test_debate_routes_yes_to_no():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
state = {"investment_debate_state": {"count": 1, "latest_speaker": "YES Advocate"}}
|
||||
assert cl.should_continue_debate(state) == "NO Advocate"
|
||||
|
||||
|
||||
def test_debate_routes_no_to_timing():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
state = {"investment_debate_state": {"count": 2, "latest_speaker": "NO Advocate"}}
|
||||
assert cl.should_continue_debate(state) == "Timing Advocate"
|
||||
|
||||
|
||||
def test_debate_routes_timing_to_yes():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=2, max_risk_discuss_rounds=1)
|
||||
state = {"investment_debate_state": {"count": 3, "latest_speaker": "Timing Advocate"}}
|
||||
assert cl.should_continue_debate(state) == "YES Advocate"
|
||||
|
||||
|
||||
def test_debate_routes_to_manager_after_max_rounds():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
state = {"investment_debate_state": {"count": 3, "latest_speaker": "Timing Advocate"}}
|
||||
assert cl.should_continue_debate(state) == "Research Manager"
|
||||
|
||||
|
||||
def test_debate_initial_routes_to_yes():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
state = {"investment_debate_state": {"count": 0, "latest_speaker": ""}}
|
||||
assert cl.should_continue_debate(state) == "YES Advocate"
|
||||
|
||||
|
||||
def test_should_continue_odds_routes_to_tools():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
msg = MagicMock()
|
||||
msg.tool_calls = [{"name": "get_market_data", "args": {}}]
|
||||
state = {"messages": [msg]}
|
||||
assert cl.should_continue_odds(state) == "tools_odds"
|
||||
|
||||
|
||||
def test_should_continue_odds_routes_to_clear():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
msg = MagicMock()
|
||||
msg.tool_calls = []
|
||||
state = {"messages": [msg]}
|
||||
assert cl.should_continue_odds(state) == "Msg Clear Odds"
|
||||
|
||||
|
||||
def test_risk_routes_aggressive_to_conservative():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
state = {"risk_debate_state": {"count": 1, "latest_speaker": "Aggressive Analyst"}}
|
||||
assert cl.should_continue_risk_analysis(state) == "Conservative Analyst"
|
||||
|
||||
|
||||
def test_risk_routes_to_judge_after_max():
|
||||
from tradingagents.graph.conditional_logic import ConditionalLogic
|
||||
cl = ConditionalLogic(max_debate_rounds=1, max_risk_discuss_rounds=1)
|
||||
state = {"risk_debate_state": {"count": 3, "latest_speaker": "Neutral Analyst"}}
|
||||
assert cl.should_continue_risk_analysis(state) == "Risk Judge"
|
||||
|
|
@ -1,76 +1,47 @@
|
|||
from typing import Annotated, Sequence
|
||||
from datetime import date, timedelta, datetime
|
||||
from typing_extensions import TypedDict, Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from tradingagents.agents import *
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langgraph.graph import END, StateGraph, START, MessagesState
|
||||
from langgraph.graph import MessagesState
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
# Researcher team state
|
||||
class InvestDebateState(TypedDict):
|
||||
bull_history: Annotated[
|
||||
str, "Bullish Conversation history"
|
||||
] # Bullish Conversation history
|
||||
bear_history: Annotated[
|
||||
str, "Bearish Conversation history"
|
||||
] # Bullish Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
current_response: Annotated[str, "Latest response"] # Last response
|
||||
judge_decision: Annotated[str, "Final judge decision"] # Last response
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
"""State for the YES/NO/Timing investment debate."""
|
||||
yes_history: str
|
||||
no_history: str
|
||||
timing_history: str
|
||||
history: str
|
||||
current_yes_response: str
|
||||
current_no_response: str
|
||||
current_timing_response: str
|
||||
latest_speaker: str
|
||||
judge_decision: str
|
||||
count: int
|
||||
|
||||
|
||||
# Risk management team state
|
||||
class RiskDebateState(TypedDict):
|
||||
aggressive_history: Annotated[
|
||||
str, "Aggressive Agent's Conversation history"
|
||||
] # Conversation history
|
||||
conservative_history: Annotated[
|
||||
str, "Conservative Agent's Conversation history"
|
||||
] # Conversation history
|
||||
neutral_history: Annotated[
|
||||
str, "Neutral Agent's Conversation history"
|
||||
] # Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||
current_aggressive_response: Annotated[
|
||||
str, "Latest response by the aggressive analyst"
|
||||
] # Last response
|
||||
current_conservative_response: Annotated[
|
||||
str, "Latest response by the conservative analyst"
|
||||
] # Last response
|
||||
current_neutral_response: Annotated[
|
||||
str, "Latest response by the neutral analyst"
|
||||
] # Last response
|
||||
judge_decision: Annotated[str, "Judge's decision"]
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
"""State for the Aggressive/Conservative/Neutral risk debate."""
|
||||
aggressive_history: str
|
||||
conservative_history: str
|
||||
neutral_history: str
|
||||
history: str
|
||||
latest_speaker: str
|
||||
current_aggressive_response: str
|
||||
current_conservative_response: str
|
||||
current_neutral_response: str
|
||||
judge_decision: str
|
||||
count: int
|
||||
|
||||
|
||||
class AgentState(MessagesState):
|
||||
company_of_interest: Annotated[str, "Company that we are interested in trading"]
|
||||
trade_date: Annotated[str, "What date we are trading at"]
|
||||
|
||||
sender: Annotated[str, "Agent that sent this message"]
|
||||
|
||||
# research step
|
||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
||||
news_report: Annotated[
|
||||
str, "Report from the News Researcher of current world affairs"
|
||||
]
|
||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
||||
|
||||
# researcher team discussion step
|
||||
investment_debate_state: Annotated[
|
||||
InvestDebateState, "Current state of the debate on if to invest or not"
|
||||
]
|
||||
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
||||
|
||||
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
|
||||
|
||||
# risk management team discussion step
|
||||
risk_debate_state: Annotated[
|
||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
||||
]
|
||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||
"""Main agent state for Polymarket prediction analysis."""
|
||||
event_id: str
|
||||
event_question: str
|
||||
trade_date: str
|
||||
sender: str
|
||||
odds_report: str
|
||||
sentiment_report: str
|
||||
news_report: str
|
||||
event_report: str
|
||||
investment_debate_state: InvestDebateState
|
||||
investment_plan: str
|
||||
trader_plan: str
|
||||
risk_debate_state: RiskDebateState
|
||||
final_decision: str
|
||||
|
|
|
|||
|
|
@ -1,67 +1,64 @@
|
|||
# TradingAgents/graph/conditional_logic.py
|
||||
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
"""Conditional routing logic for the trading agents graph."""
|
||||
|
||||
|
||||
class ConditionalLogic:
|
||||
"""Handles conditional logic for determining graph flow."""
|
||||
"""Handles conditional routing decisions in the graph."""
|
||||
|
||||
def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1):
|
||||
"""Initialize with configuration parameters."""
|
||||
self.max_debate_rounds = max_debate_rounds
|
||||
self.max_risk_discuss_rounds = max_risk_discuss_rounds
|
||||
|
||||
def should_continue_market(self, state: AgentState):
|
||||
"""Determine if market analysis should continue."""
|
||||
def should_continue_odds(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_market"
|
||||
return "Msg Clear Market"
|
||||
return "tools_odds"
|
||||
return "Msg Clear Odds"
|
||||
|
||||
def should_continue_social(self, state: AgentState):
|
||||
"""Determine if social media analysis should continue."""
|
||||
def should_continue_social(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_social"
|
||||
return "Msg Clear Social"
|
||||
|
||||
def should_continue_news(self, state: AgentState):
|
||||
"""Determine if news analysis should continue."""
|
||||
def should_continue_news(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_news"
|
||||
return "Msg Clear News"
|
||||
|
||||
def should_continue_fundamentals(self, state: AgentState):
|
||||
"""Determine if fundamentals analysis should continue."""
|
||||
def should_continue_event(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_fundamentals"
|
||||
return "Msg Clear Fundamentals"
|
||||
return "tools_event"
|
||||
return "Msg Clear Event"
|
||||
|
||||
def should_continue_debate(self, state: AgentState) -> str:
|
||||
"""Determine if debate should continue."""
|
||||
|
||||
if (
|
||||
state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds
|
||||
): # 3 rounds of back-and-forth between 2 agents
|
||||
def should_continue_debate(self, state):
|
||||
"""Route 3-way YES/NO/Timing debate. Mirrors risk debate pattern."""
|
||||
count = state["investment_debate_state"]["count"]
|
||||
if count >= 3 * self.max_debate_rounds:
|
||||
return "Research Manager"
|
||||
if state["investment_debate_state"]["current_response"].startswith("Bull"):
|
||||
return "Bear Researcher"
|
||||
return "Bull Researcher"
|
||||
latest = state["investment_debate_state"].get("latest_speaker", "")
|
||||
if latest.startswith("YES"):
|
||||
return "NO Advocate"
|
||||
elif latest.startswith("NO"):
|
||||
return "Timing Advocate"
|
||||
else:
|
||||
# Initial entry or after Timing -> start with YES
|
||||
return "YES Advocate"
|
||||
|
||||
def should_continue_risk_analysis(self, state: AgentState) -> str:
|
||||
"""Determine if risk analysis should continue."""
|
||||
if (
|
||||
state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds
|
||||
): # 3 rounds of back-and-forth between 3 agents
|
||||
def should_continue_risk_analysis(self, state):
|
||||
"""Route 3-way risk debate. Unchanged from original."""
|
||||
count = state["risk_debate_state"]["count"]
|
||||
if count >= 3 * self.max_risk_discuss_rounds:
|
||||
return "Risk Judge"
|
||||
if state["risk_debate_state"]["latest_speaker"].startswith("Aggressive"):
|
||||
latest = state["risk_debate_state"].get("latest_speaker", "")
|
||||
if latest.startswith("Aggressive"):
|
||||
return "Conservative Analyst"
|
||||
if state["risk_debate_state"]["latest_speaker"].startswith("Conservative"):
|
||||
elif latest.startswith("Conservative"):
|
||||
return "Neutral Analyst"
|
||||
return "Aggressive Analyst"
|
||||
else:
|
||||
return "Aggressive Analyst"
|
||||
|
|
|
|||
Loading…
Reference in New Issue