feat: parallelize analyst agents for ~3x speedup
Run all 4 analysts (Market, Social, News, Fundamentals) concurrently using asyncio.gather instead of sequentially. Each analyst gets its own isolated message state and tool-calling loop. Cuts analyst phase from ~8-9 min to ~2-3 min (total analysis from ~11 min to ~4-5 min). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f5519b9efe
commit
223879bc04
23
app.py
23
app.py
|
|
@ -53,6 +53,7 @@ def build_config():
|
||||||
"fundamental_data": "yfinance",
|
"fundamental_data": "yfinance",
|
||||||
"news_data": "yfinance",
|
"news_data": "yfinance",
|
||||||
}
|
}
|
||||||
|
config["parallel_analysts"] = True
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -122,6 +123,28 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
final_state = None
|
final_state = None
|
||||||
prev_statuses = {}
|
prev_statuses = {}
|
||||||
|
|
||||||
|
# Emit all analysts as "in_progress" immediately (they run in parallel)
|
||||||
|
analyst_name_map = {
|
||||||
|
"market": "Market Analyst",
|
||||||
|
"social": "Social Analyst",
|
||||||
|
"news": "News Analyst",
|
||||||
|
"fundamentals": "Fundamentals Analyst",
|
||||||
|
}
|
||||||
|
for analyst_type in selected_analysts:
|
||||||
|
agent_name = analyst_name_map[analyst_type]
|
||||||
|
buf.update_agent_status(agent_name, "in_progress")
|
||||||
|
st = get_stats_dict(stats_handler, buf, start_time)
|
||||||
|
evt = {
|
||||||
|
"type": "agent_update",
|
||||||
|
"agent": agent_name,
|
||||||
|
"stage": "analysts",
|
||||||
|
"status": "in_progress",
|
||||||
|
"stats": st,
|
||||||
|
}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
prev_statuses[agent_name] = "in_progress"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in graph.graph.astream(init_state, **args):
|
async for chunk in graph.graph.astream(init_state, **args):
|
||||||
final_state = chunk
|
final_state = chunk
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""Parallel analyst execution for TradingAgents.
|
||||||
|
|
||||||
|
Runs all analyst agents (Market, Social, News, Fundamentals) concurrently
|
||||||
|
instead of sequentially, cutting the analyst phase from ~8-9 min to ~2-3 min.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||||
|
|
||||||
|
|
||||||
|
def create_parallel_analyst_node(analyst_fns, tool_nodes, selected_analysts):
|
||||||
|
"""Create a single LangGraph node that runs all analysts in parallel.
|
||||||
|
|
||||||
|
Each analyst gets its own isolated message state and runs its complete
|
||||||
|
tool-calling loop independently. Results are merged at the end.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
analyst_fns: dict mapping analyst type (e.g. "market") to node function
|
||||||
|
tool_nodes: dict mapping analyst type to ToolNode instance
|
||||||
|
selected_analysts: list of analyst types to run
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def parallel_analysts_node(state):
|
||||||
|
"""Run all analysts concurrently and merge their reports."""
|
||||||
|
|
||||||
|
async def run_single(analyst_type):
|
||||||
|
"""Run one analyst through its complete tool-calling loop."""
|
||||||
|
fn = analyst_fns[analyst_type]
|
||||||
|
tn = tool_nodes[analyst_type]
|
||||||
|
|
||||||
|
# Each analyst gets its own isolated message state
|
||||||
|
local_state = {
|
||||||
|
"messages": list(state["messages"]),
|
||||||
|
"trade_date": state["trade_date"],
|
||||||
|
"company_of_interest": state["company_of_interest"],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for _ in range(10): # safety limit on tool rounds
|
||||||
|
result = await asyncio.to_thread(fn, local_state)
|
||||||
|
ai_msg = result["messages"][0]
|
||||||
|
local_state["messages"] = local_state["messages"] + [ai_msg]
|
||||||
|
|
||||||
|
if not ai_msg.tool_calls:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Process tool calls
|
||||||
|
tool_result = await asyncio.to_thread(tn.invoke, local_state)
|
||||||
|
local_state["messages"] = (
|
||||||
|
local_state["messages"] + tool_result["messages"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return only report fields (not messages)
|
||||||
|
return {k: v for k, v in result.items() if k != "messages"}
|
||||||
|
|
||||||
|
# Run all analysts concurrently
|
||||||
|
tasks = [run_single(at) for at in selected_analysts if at in analyst_fns]
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Merge all report fields
|
||||||
|
merged = {}
|
||||||
|
for r in results:
|
||||||
|
merged.update(r)
|
||||||
|
|
||||||
|
# Clear messages and add placeholder (same as Msg Clear nodes)
|
||||||
|
messages = state.get("messages", [])
|
||||||
|
removal_ops = [
|
||||||
|
RemoveMessage(id=m.id)
|
||||||
|
for m in messages
|
||||||
|
if hasattr(m, "id") and m.id
|
||||||
|
]
|
||||||
|
merged["messages"] = removal_ops + [HumanMessage(content="Continue")]
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
return parallel_analysts_node
|
||||||
|
|
@ -9,6 +9,7 @@ from tradingagents.agents import *
|
||||||
from tradingagents.agents.utils.agent_states import AgentState
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
|
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
|
from .parallel_analysts import create_parallel_analyst_node
|
||||||
|
|
||||||
|
|
||||||
class GraphSetup:
|
class GraphSetup:
|
||||||
|
|
@ -38,7 +39,8 @@ class GraphSetup:
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
self, selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
|
parallel=False,
|
||||||
):
|
):
|
||||||
"""Set up and compile the agent workflow graph.
|
"""Set up and compile the agent workflow graph.
|
||||||
|
|
||||||
|
|
@ -48,11 +50,12 @@ class GraphSetup:
|
||||||
- "social": Social media analyst
|
- "social": Social media analyst
|
||||||
- "news": News analyst
|
- "news": News analyst
|
||||||
- "fundamentals": Fundamentals analyst
|
- "fundamentals": Fundamentals analyst
|
||||||
|
parallel (bool): Run analysts in parallel instead of sequentially.
|
||||||
"""
|
"""
|
||||||
if len(selected_analysts) == 0:
|
if len(selected_analysts) == 0:
|
||||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||||
|
|
||||||
# Create analyst nodes
|
# Create analyst node functions and tool nodes
|
||||||
analyst_nodes = {}
|
analyst_nodes = {}
|
||||||
delete_nodes = {}
|
delete_nodes = {}
|
||||||
tool_nodes = {}
|
tool_nodes = {}
|
||||||
|
|
@ -108,13 +111,20 @@ class GraphSetup:
|
||||||
# Create workflow
|
# Create workflow
|
||||||
workflow = StateGraph(AgentState)
|
workflow = StateGraph(AgentState)
|
||||||
|
|
||||||
# Add analyst nodes to the graph
|
if parallel:
|
||||||
for analyst_type, node in analyst_nodes.items():
|
# Single node runs all analysts concurrently
|
||||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
parallel_node = create_parallel_analyst_node(
|
||||||
workflow.add_node(
|
analyst_nodes, tool_nodes, selected_analysts
|
||||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
|
|
||||||
)
|
)
|
||||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
workflow.add_node("Parallel Analysts", parallel_node)
|
||||||
|
else:
|
||||||
|
# Add analyst nodes individually for sequential execution
|
||||||
|
for analyst_type, node in analyst_nodes.items():
|
||||||
|
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
||||||
|
workflow.add_node(
|
||||||
|
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
|
||||||
|
)
|
||||||
|
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
||||||
|
|
||||||
# Add other nodes
|
# Add other nodes
|
||||||
workflow.add_node("Bull Researcher", bull_researcher_node)
|
workflow.add_node("Bull Researcher", bull_researcher_node)
|
||||||
|
|
@ -127,32 +137,34 @@ class GraphSetup:
|
||||||
workflow.add_node("Risk Judge", risk_manager_node)
|
workflow.add_node("Risk Judge", risk_manager_node)
|
||||||
|
|
||||||
# Define edges
|
# Define edges
|
||||||
# Start with the first analyst
|
if parallel:
|
||||||
first_analyst = selected_analysts[0]
|
# Parallel: START → Parallel Analysts → Bull Researcher
|
||||||
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
|
workflow.add_edge(START, "Parallel Analysts")
|
||||||
|
workflow.add_edge("Parallel Analysts", "Bull Researcher")
|
||||||
|
else:
|
||||||
|
# Sequential: START → Analyst 1 → ... → Analyst N → Bull Researcher
|
||||||
|
first_analyst = selected_analysts[0]
|
||||||
|
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
|
||||||
|
|
||||||
# Connect analysts in sequence
|
for i, analyst_type in enumerate(selected_analysts):
|
||||||
for i, analyst_type in enumerate(selected_analysts):
|
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
||||||
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
current_tools = f"tools_{analyst_type}"
|
||||||
current_tools = f"tools_{analyst_type}"
|
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
||||||
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
|
||||||
|
|
||||||
# Add conditional edges for current analyst
|
workflow.add_conditional_edges(
|
||||||
workflow.add_conditional_edges(
|
current_analyst,
|
||||||
current_analyst,
|
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
|
||||||
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
|
[current_tools, current_clear],
|
||||||
[current_tools, current_clear],
|
)
|
||||||
)
|
workflow.add_edge(current_tools, current_analyst)
|
||||||
workflow.add_edge(current_tools, current_analyst)
|
|
||||||
|
|
||||||
# Connect to next analyst or to Bull Researcher if this is the last analyst
|
if i < len(selected_analysts) - 1:
|
||||||
if i < len(selected_analysts) - 1:
|
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
|
||||||
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
|
workflow.add_edge(current_clear, next_analyst)
|
||||||
workflow.add_edge(current_clear, next_analyst)
|
else:
|
||||||
else:
|
workflow.add_edge(current_clear, "Bull Researcher")
|
||||||
workflow.add_edge(current_clear, "Bull Researcher")
|
|
||||||
|
|
||||||
# Add remaining edges
|
# Add remaining edges (same for both modes)
|
||||||
workflow.add_conditional_edges(
|
workflow.add_conditional_edges(
|
||||||
"Bull Researcher",
|
"Bull Researcher",
|
||||||
self.conditional_logic.should_continue_debate,
|
self.conditional_logic.should_continue_debate,
|
||||||
|
|
|
||||||
|
|
@ -127,8 +127,9 @@ class TradingAgentsGraph:
|
||||||
self.ticker = None
|
self.ticker = None
|
||||||
self.log_states_dict = {} # date to full state dict
|
self.log_states_dict = {} # date to full state dict
|
||||||
|
|
||||||
# Set up the graph
|
# Set up the graph (parallel analysts for speed when enabled)
|
||||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
parallel = self.config.get("parallel_analysts", False)
|
||||||
|
self.graph = self.graph_setup.setup_graph(selected_analysts, parallel=parallel)
|
||||||
|
|
||||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||||
"""Get provider-specific kwargs for LLM client creation."""
|
"""Get provider-specific kwargs for LLM client creation."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue