feat: parallelize analyst agents for ~3x speedup
Run all 4 analysts (Market, Social, News, Fundamentals) concurrently using asyncio.gather instead of sequentially. Each analyst gets its own isolated message state and tool-calling loop. Cuts analyst phase from ~8-9 min to ~2-3 min (total analysis from ~11 min to ~4-5 min). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f5519b9efe
commit
223879bc04
23
app.py
23
app.py
|
|
@ -53,6 +53,7 @@ def build_config():
|
|||
"fundamental_data": "yfinance",
|
||||
"news_data": "yfinance",
|
||||
}
|
||||
config["parallel_analysts"] = True
|
||||
return config
|
||||
|
||||
|
||||
|
|
@ -122,6 +123,28 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
|||
final_state = None
|
||||
prev_statuses = {}
|
||||
|
||||
# Emit all analysts as "in_progress" immediately (they run in parallel)
|
||||
analyst_name_map = {
|
||||
"market": "Market Analyst",
|
||||
"social": "Social Analyst",
|
||||
"news": "News Analyst",
|
||||
"fundamentals": "Fundamentals Analyst",
|
||||
}
|
||||
for analyst_type in selected_analysts:
|
||||
agent_name = analyst_name_map[analyst_type]
|
||||
buf.update_agent_status(agent_name, "in_progress")
|
||||
st = get_stats_dict(stats_handler, buf, start_time)
|
||||
evt = {
|
||||
"type": "agent_update",
|
||||
"agent": agent_name,
|
||||
"stage": "analysts",
|
||||
"status": "in_progress",
|
||||
"stats": st,
|
||||
}
|
||||
state["events"].append(evt)
|
||||
await q.put(evt)
|
||||
prev_statuses[agent_name] = "in_progress"
|
||||
|
||||
try:
|
||||
async for chunk in graph.graph.astream(init_state, **args):
|
||||
final_state = chunk
|
||||
|
|
|
|||
|
|
@ -0,0 +1,76 @@
|
|||
"""Parallel analyst execution for TradingAgents.
|
||||
|
||||
Runs all analyst agents (Market, Social, News, Fundamentals) concurrently
|
||||
instead of sequentially, cutting the analyst phase from ~8-9 min to ~2-3 min.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
|
||||
|
||||
def create_parallel_analyst_node(analyst_fns, tool_nodes, selected_analysts):
    """Create a single LangGraph node that runs all analysts in parallel.

    Each analyst gets its own isolated message state and runs its complete
    tool-calling loop independently. Results are merged at the end.

    Args:
        analyst_fns: dict mapping analyst type (e.g. "market") to node function
        tool_nodes: dict mapping analyst type to ToolNode instance
        selected_analysts: list of analyst types to run

    Returns:
        An async node function suitable for registration on the graph.
    """

    async def parallel_analysts_node(state):
        """Run all analysts concurrently and merge their reports."""

        async def run_single(analyst_type):
            """Run one analyst through its complete tool-calling loop.

            Returns the analyst's report fields (everything except
            "messages") from its final node result.
            """
            fn = analyst_fns[analyst_type]
            tn = tool_nodes[analyst_type]

            # Each analyst gets its own isolated copy of the message list so
            # concurrently running analysts never interleave histories.
            local_state = {
                "messages": list(state["messages"]),
                "trade_date": state["trade_date"],
                "company_of_interest": state["company_of_interest"],
            }

            result = {}
            for _ in range(10):  # safety limit on tool rounds
                # Analyst functions are synchronous; run them off the event
                # loop so all analysts genuinely overlap.
                result = await asyncio.to_thread(fn, local_state)
                ai_msg = result["messages"][0]
                # local_state["messages"] is our private copy, so in-place
                # append is safe (the original rebuilt the list each round).
                local_state["messages"].append(ai_msg)

                # Done once the model stops requesting tools. getattr guards
                # against message types that have no tool_calls attribute.
                if not getattr(ai_msg, "tool_calls", None):
                    break

                # Execute the requested tool calls and feed results back into
                # this analyst's private history for the next round.
                tool_result = await asyncio.to_thread(tn.invoke, local_state)
                local_state["messages"].extend(tool_result["messages"])

            # Return only report fields (not messages); each analyst writes
            # distinct report keys, so the merge below cannot collide.
            return {k: v for k, v in result.items() if k != "messages"}

        # Run all selected analysts concurrently. Guard BOTH dicts: the
        # original checked only analyst_fns, so a missing tool node raised
        # KeyError from inside gather instead of being skipped.
        tasks = [
            run_single(at)
            for at in selected_analysts
            if at in analyst_fns and at in tool_nodes
        ]
        results = await asyncio.gather(*tasks)

        # Merge all report fields into a single state update.
        merged = {}
        for r in results:
            merged.update(r)

        # Clear messages and add a placeholder, mirroring the sequential
        # graph's "Msg Clear" nodes so downstream agents see a fresh history.
        removal_ops = [
            RemoveMessage(id=m.id)
            for m in state.get("messages", [])
            if getattr(m, "id", None)
        ]
        merged["messages"] = removal_ops + [HumanMessage(content="Continue")]

        return merged

    return parallel_analysts_node
|
||||
|
|
@ -9,6 +9,7 @@ from tradingagents.agents import *
|
|||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .parallel_analysts import create_parallel_analyst_node
|
||||
|
||||
|
||||
class GraphSetup:
|
||||
|
|
@ -38,7 +39,8 @@ class GraphSetup:
|
|||
self.conditional_logic = conditional_logic
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
parallel=False,
|
||||
):
|
||||
"""Set up and compile the agent workflow graph.
|
||||
|
||||
|
|
@ -48,11 +50,12 @@ class GraphSetup:
|
|||
- "social": Social media analyst
|
||||
- "news": News analyst
|
||||
- "fundamentals": Fundamentals analyst
|
||||
parallel (bool): Run analysts in parallel instead of sequentially.
|
||||
"""
|
||||
if len(selected_analysts) == 0:
|
||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||
|
||||
# Create analyst nodes
|
||||
# Create analyst node functions and tool nodes
|
||||
analyst_nodes = {}
|
||||
delete_nodes = {}
|
||||
tool_nodes = {}
|
||||
|
|
@ -108,13 +111,20 @@ class GraphSetup:
|
|||
# Create workflow
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
# Add analyst nodes to the graph
|
||||
for analyst_type, node in analyst_nodes.items():
|
||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
||||
workflow.add_node(
|
||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
|
||||
if parallel:
|
||||
# Single node runs all analysts concurrently
|
||||
parallel_node = create_parallel_analyst_node(
|
||||
analyst_nodes, tool_nodes, selected_analysts
|
||||
)
|
||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
||||
workflow.add_node("Parallel Analysts", parallel_node)
|
||||
else:
|
||||
# Add analyst nodes individually for sequential execution
|
||||
for analyst_type, node in analyst_nodes.items():
|
||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
||||
workflow.add_node(
|
||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
|
||||
)
|
||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
||||
|
||||
# Add other nodes
|
||||
workflow.add_node("Bull Researcher", bull_researcher_node)
|
||||
|
|
@ -127,32 +137,34 @@ class GraphSetup:
|
|||
workflow.add_node("Risk Judge", risk_manager_node)
|
||||
|
||||
# Define edges
|
||||
# Start with the first analyst
|
||||
first_analyst = selected_analysts[0]
|
||||
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
|
||||
if parallel:
|
||||
# Parallel: START → Parallel Analysts → Bull Researcher
|
||||
workflow.add_edge(START, "Parallel Analysts")
|
||||
workflow.add_edge("Parallel Analysts", "Bull Researcher")
|
||||
else:
|
||||
# Sequential: START → Analyst 1 → ... → Analyst N → Bull Researcher
|
||||
first_analyst = selected_analysts[0]
|
||||
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
|
||||
|
||||
# Connect analysts in sequence
|
||||
for i, analyst_type in enumerate(selected_analysts):
|
||||
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
||||
current_tools = f"tools_{analyst_type}"
|
||||
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
||||
for i, analyst_type in enumerate(selected_analysts):
|
||||
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
||||
current_tools = f"tools_{analyst_type}"
|
||||
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
||||
|
||||
# Add conditional edges for current analyst
|
||||
workflow.add_conditional_edges(
|
||||
current_analyst,
|
||||
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
|
||||
[current_tools, current_clear],
|
||||
)
|
||||
workflow.add_edge(current_tools, current_analyst)
|
||||
workflow.add_conditional_edges(
|
||||
current_analyst,
|
||||
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
|
||||
[current_tools, current_clear],
|
||||
)
|
||||
workflow.add_edge(current_tools, current_analyst)
|
||||
|
||||
# Connect to next analyst or to Bull Researcher if this is the last analyst
|
||||
if i < len(selected_analysts) - 1:
|
||||
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
|
||||
workflow.add_edge(current_clear, next_analyst)
|
||||
else:
|
||||
workflow.add_edge(current_clear, "Bull Researcher")
|
||||
if i < len(selected_analysts) - 1:
|
||||
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
|
||||
workflow.add_edge(current_clear, next_analyst)
|
||||
else:
|
||||
workflow.add_edge(current_clear, "Bull Researcher")
|
||||
|
||||
# Add remaining edges
|
||||
# Add remaining edges (same for both modes)
|
||||
workflow.add_conditional_edges(
|
||||
"Bull Researcher",
|
||||
self.conditional_logic.should_continue_debate,
|
||||
|
|
|
|||
|
|
@ -127,8 +127,9 @@ class TradingAgentsGraph:
|
|||
self.ticker = None
|
||||
self.log_states_dict = {} # date to full state dict
|
||||
|
||||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
# Set up the graph (parallel analysts for speed when enabled)
|
||||
parallel = self.config.get("parallel_analysts", False)
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts, parallel=parallel)
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
|
|
|
|||
Loading…
Reference in New Issue