TradingAgents/tradingagents/graph/conditional_logic.py

137 lines
4.5 KiB
Python

# -*- coding: utf-8 -*-
# TradingAgentsX/graph/conditional_logic.py
from tradingagents.agents.utils.agent_states import AgentState
class ConditionalLogic:
"""
處理用於確定圖流程的條件邏輯。
這個類別定義了在圖中不同節點之間轉換的規則,
例如,決定下一個應該執行的代理或是否繼續一個循環。
"""
def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1):
"""
使用設定參數進行初始化。
Args:
max_debate_rounds (int): 投資辯論的最大回合數。
max_risk_discuss_rounds (int): 風險討論的最大回合數。
"""
self.max_debate_rounds = max_debate_rounds
self.max_risk_discuss_rounds = max_risk_discuss_rounds
def should_continue_market(self, state: AgentState):
"""
判斷市場分析是否應該繼續。
如果最後一條訊息包含工具呼叫,則表示代理需要使用工具,
流程應該轉到市場工具節點。否則,分析完成。
Args:
state (AgentState): 當前的代理狀態。
Returns:
str: 下一個節點的名稱。
"""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools_market"
return "Msg Clear Market"
def should_continue_social(self, state: AgentState):
"""
判斷社群媒體分析是否應該繼續。
邏輯與 `should_continue_market` 類似。
Args:
state (AgentState): 當前的代理狀態。
Returns:
str: 下一個節點的名稱。
"""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools_social"
return "Msg Clear Social"
def should_continue_news(self, state: AgentState):
"""
判斷新聞分析是否應該繼續。
邏輯與 `should_continue_market` 類似。
Args:
state (AgentState): 當前的代理狀態。
Returns:
str: 下一個節點的名稱。
"""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools_news"
return "Msg Clear News"
def should_continue_fundamentals(self, state: AgentState):
"""
判斷基本面分析是否應該繼續。
邏輯與 `should_continue_market` 類似。
Args:
state (AgentState): 當前的代理狀態。
Returns:
str: 下一個節點的名稱。
"""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools_fundamentals"
return "Msg Clear Fundamentals"
def should_continue_debate(self, state: AgentState) -> str:
"""
判斷投資辯論是否應該繼續。
如果辯論回合數達到上限,則由研究經理做出最終決定。
否則,在看漲和看跌研究員之間輪流進行。
Args:
state (AgentState): 當前的代理狀態。
Returns:
str: 下一個節點的名稱。
"""
# 2 個代理之間的來回辯論
if (
state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds
):
return "Research Manager"
# 檢查中文前綴(因為研究員使用中文格式化響應)
if state["investment_debate_state"]["current_response"].startswith("看漲"):
return "Bear Researcher"
return "Bull Researcher"
def should_continue_risk_analysis(self, state: AgentState) -> str:
"""
判斷風險分析是否應該繼續。
如果討論回合數達到上限,則由風險裁判做出最終決定。
否則,在激進、保守和中立分析師之間輪流進行。
Args:
state (AgentState): 當前的代理狀態。
Returns:
str: 下一個節點的名稱。
"""
# 3 個代理之間的來回討論
if (
state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds
):
return "Risk Judge"
if state["risk_debate_state"]["latest_speaker"].startswith("Risky"):
return "Safe Analyst"
if state["risk_debate_state"]["latest_speaker"].startswith("Safe"):
return "Neutral Analyst"
return "Risky Analyst"