TradingAgents/tradingagents/graph/setup.py

324 lines
13 KiB
Python

# TradingAgents/graph/setup.py
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.agents.utils.agent_states import (
AgentState,
SocialAnalystState,
NewsAnalystState,
FundamentalsAnalystState
)
from .enhanced_conditional_logic import EnhancedConditionalLogic
class GraphSetup:
    """Handles the setup and configuration of the agent graph.

    The constructor only stores the LLMs, tool nodes, memories and routing
    logic; ``setup_graph`` uses them to build and compile the workflow.
    """

    # Analyst types accepted by ``setup_graph``. "market" is always built and
    # runs directly in the main graph; the others are optional and each runs
    # inside its own isolated subgraph (see ``build_analyst_subgraph``).
    _ANALYST_TYPES = ("market", "social", "news", "fundamentals")

    def __init__(
        self,
        quick_thinking_llm: ChatOpenAI,
        deep_thinking_llm: ChatOpenAI,
        tool_nodes: Dict[str, ToolNode],
        bull_memory,
        bear_memory,
        trader_memory,
        invest_judge_memory,
        risk_manager_memory,
        conditional_logic: EnhancedConditionalLogic,
    ):
        """Initialize with required components.

        Args:
            quick_thinking_llm: LLM for the analysts, the bull/bear
                researchers, the trader and the three risk debators.
            deep_thinking_llm: LLM for the research manager and risk judge.
            tool_nodes: Tool nodes keyed by analyst type. The "market" entry
                is always used; the others only for selected analysts.
            bull_memory: Memory passed to the bull researcher.
            bear_memory: Memory passed to the bear researcher.
            trader_memory: Memory passed to the trader.
            invest_judge_memory: Memory passed to the research manager.
            risk_manager_memory: Memory passed to the risk judge.
            conditional_logic: Provides the routing functions used by the
                graph's conditional edges.
        """
        self.quick_thinking_llm = quick_thinking_llm
        self.deep_thinking_llm = deep_thinking_llm
        self.tool_nodes = tool_nodes
        self.bull_memory = bull_memory
        self.bear_memory = bear_memory
        self.trader_memory = trader_memory
        self.invest_judge_memory = invest_judge_memory
        self.risk_manager_memory = risk_manager_memory
        self.conditional_logic = conditional_logic

    def build_analyst_subgraph(self, analyst_node, delete_node, tool_node, check_condition, name, state_schema):
        """Builder for Analyst Subgraphs (Isolation Sandbox).

        Each analyst runs in its own StateGraph so that parallel analysts do
        not share one 'messages' list with each other.

        Flow: START -> Init_Clear -> Analyst -> [Tools -> Analyst]* -> END

        Args:
            analyst_node: The main agent function.
            delete_node: Function that clears messages; it runs first so the
                analyst starts from a clean slate.
            tool_node: The tool execution node.
            check_condition: Routing function called after each Analyst step.
                It must return either ``f"tools_{name}"`` (run the tools) or
                ``f"Msg Clear {name.capitalize()}"`` (finish the subgraph).
            name: Analyst type, e.g. "social"; used to build those route keys.
            state_schema: The strictly typed State class for this subgraph.

        Returns:
            The compiled subgraph, ready to be added to a parent graph as a
            single node.
        """
        # Strict per-analyst schema instead of the full AgentState.
        subgraph = StateGraph(state_schema)

        subgraph.add_node("Init_Clear", delete_node)
        subgraph.add_node("Analyst", analyst_node)
        subgraph.add_node("Tools", tool_node)

        # Wipe the messages inherited from the parent before the analyst runs.
        subgraph.add_edge(START, "Init_Clear")
        subgraph.add_edge("Init_Clear", "Analyst")
        # The routing function was written for the flat main graph, so its
        # external route names are mapped onto this subgraph's internal nodes.
        subgraph.add_conditional_edges(
            "Analyst",
            check_condition,
            {
                f"tools_{name}": "Tools",
                f"Msg Clear {name.capitalize()}": END,
            },
        )
        subgraph.add_edge("Tools", "Analyst")
        return subgraph.compile()

    @classmethod
    def _resolve_optional_analysts(cls, selected_analysts):
        """Validate ``selected_analysts`` and return the optional analysts.

        Returns:
            The selected analyst types other than "market", with duplicates
            removed and the caller's order preserved.

        Raises:
            ValueError: if nothing is selected or an analyst type is unknown.
        """
        selected = list(selected_analysts)
        if not selected:
            raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
        unknown = [a for a in selected if a not in cls._ANALYST_TYPES]
        if unknown:
            # Without this check an unknown type gets edges but no node and
            # only fails later inside LangGraph's compile step.
            raise ValueError(
                "Trading Agents Graph Setup Error: unknown analyst type(s) "
                f"{unknown}; valid options are {list(cls._ANALYST_TYPES)}"
            )
        # The market analyst is added separately (it is mandatory), so drop it
        # here. Duplicates are dropped because each would become a second node
        # with the same name.
        return [a for a in dict.fromkeys(selected) if a != "market"]

    def setup_graph(
        self, selected_analysts=("market", "social", "news", "fundamentals")
    ):
        """Set up and compile the agent workflow graph.

        Topology:
            START -> Market Analyst <-> tools_market -> Msg Clear Market
                  -> [selected analyst subgraphs, in parallel] -> Analyst Sync
                  -> Bull/Bear debate -> Research Manager -> Trader
                  -> [Risky, Safe, Neutral Analysts, in parallel] -> Risk Sync
                  -> Risk Judge -> END (or the risk-gate routes, if available)

        Args:
            selected_analysts: Iterable of analyst types to include. Options are:
                - "market": Market analyst (always included, even if omitted)
                - "social": Social media analyst
                - "news": News analyst
                - "fundamentals": Fundamentals analyst
                Duplicates are ignored.

        Returns:
            The compiled workflow graph.

        Raises:
            ValueError: if ``selected_analysts`` is empty or contains an
                unknown analyst type.
        """
        other_analysts = self._resolve_optional_analysts(selected_analysts)
        branch_names = [f"{a.capitalize()} Analyst" for a in other_analysts]

        analyst_factories = {
            "social": create_social_media_analyst,
            "news": create_news_analyst,
            "fundamentals": create_fundamentals_analyst,
        }
        # Strict state schema for each optional analyst's subgraph.
        schema_map = {
            "social": SocialAnalystState,
            "news": NewsAnalystState,
            "fundamentals": FundamentalsAnalystState,
        }

        workflow = StateGraph(AgentState)

        # --- Nodes ---------------------------------------------------------
        # 1. Market analyst (mandatory): it must enable regime detection
        #    before any other analyst runs, so it lives in the main graph.
        workflow.add_node("Market Analyst", create_market_analyst(self.quick_thinking_llm))
        workflow.add_node("Msg Clear Market", create_msg_delete())
        workflow.add_node("tools_market", self.tool_nodes["market"])

        # 2. Optional analysts: each isolated subgraph is one node here.
        #    LangGraph passes state into the subgraph and merges its output.
        for analyst_type, branch_name in zip(other_analysts, branch_names):
            analyst_subgraph = self.build_analyst_subgraph(
                analyst_node=analyst_factories[analyst_type](self.quick_thinking_llm),
                delete_node=create_msg_delete(),
                tool_node=self.tool_nodes[analyst_type],
                check_condition=getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
                name=analyst_type,
                state_schema=schema_map[analyst_type],
            )
            workflow.add_node(branch_name, analyst_subgraph)

        # 3. Research, trading and risk nodes.
        workflow.add_node(
            "Bull Researcher",
            create_bull_researcher(self.quick_thinking_llm, self.bull_memory),
        )
        workflow.add_node(
            "Bear Researcher",
            create_bear_researcher(self.quick_thinking_llm, self.bear_memory),
        )
        workflow.add_node(
            "Research Manager",
            create_research_manager(self.deep_thinking_llm, self.invest_judge_memory),
        )
        workflow.add_node("Trader", create_trader(self.quick_thinking_llm, self.trader_memory))
        workflow.add_node("Risky Analyst", create_risky_debator(self.quick_thinking_llm))
        workflow.add_node("Neutral Analyst", create_neutral_debator(self.quick_thinking_llm))
        workflow.add_node("Safe Analyst", create_safe_debator(self.quick_thinking_llm))
        workflow.add_node(
            "Risk Judge",
            create_risk_manager(self.deep_thinking_llm, self.risk_manager_memory),
        )

        def sync_passthrough(state: AgentState):
            """Fan-in point for parallel branches; adds no state update."""
            return {}

        # NOTE(review): the fan-ins below use one edge per branch rather than
        # LangGraph's list-form join (add_edge([...], sync)). That acts as a
        # barrier only because every branch is a single node, so all of them
        # finish in the same superstep. Revisit if a branch becomes multi-step.

        # --- Edges: analyst stage -------------------------------------------
        workflow.add_edge(START, "Market Analyst")
        workflow.add_conditional_edges(
            "Market Analyst",
            self.conditional_logic.should_continue_market,
            ["tools_market", "Msg Clear Market"],
        )
        workflow.add_edge("tools_market", "Market Analyst")

        if branch_names:
            # FAN-OUT from the market analyst to every selected analyst, then
            # FAN-IN on the sync node before the debate starts.
            workflow.add_node("Analyst Sync", sync_passthrough)
            for branch_name in branch_names:
                workflow.add_edge("Msg Clear Market", branch_name)
                workflow.add_edge(branch_name, "Analyst Sync")
            workflow.add_edge("Analyst Sync", "Bull Researcher")
        else:
            # Market-only run: no parallel stage, so no sync node. An orphan
            # sync node would have no incoming edge and be unreachable.
            workflow.add_edge("Msg Clear Market", "Bull Researcher")

        # --- Edges: bull/bear debate ----------------------------------------
        # Each researcher can hand over to the other, loop back to itself
        # (REJECTION LOOP), or conclude at the Research Manager.
        for speaker, opponent in (
            ("Bull Researcher", "Bear Researcher"),
            ("Bear Researcher", "Bull Researcher"),
        ):
            workflow.add_conditional_edges(
                speaker,
                self.conditional_logic.should_continue_debate_with_validation,
                {
                    opponent: opponent,
                    speaker: speaker,  # REJECTION LOOP
                    "Research Manager": "Research Manager",
                },
            )
        workflow.add_edge("Research Manager", "Trader")

        # --- Edges: parallel risk review (star topology) --------------------
        # The Trader's plan is broadcast to all three critics; the judge runs
        # once, after the sync node, on their merged state.
        workflow.add_node("Risk Sync", sync_passthrough)
        for debator in ("Risky Analyst", "Safe Analyst", "Neutral Analyst"):
            workflow.add_edge("Trader", debator)
            workflow.add_edge(debator, "Risk Sync")
        workflow.add_edge("Risk Sync", "Risk Judge")

        # --- Edges: judge -> END (or enhanced risk-gate routing) ------------
        if hasattr(self.conditional_logic, 'should_proceed_after_risk_gate'):
            workflow.add_conditional_edges(
                "Risk Judge",
                self.conditional_logic.should_proceed_after_risk_gate,
                {
                    "END": END,
                    "Market Analyst": "Market Analyst",
                    "Risk Manager Revision": "Trader",  # Send back to Trader to fix plan
                    "Execute Trade": END,
                },
            )
        else:
            workflow.add_edge("Risk Judge", END)

        return workflow.compile()