diff --git a/app.py b/app.py index 84d44d72..d9d3ac28 100644 --- a/app.py +++ b/app.py @@ -53,6 +53,7 @@ def build_config(): "fundamental_data": "yfinance", "news_data": "yfinance", } + config["parallel_analysts"] = True return config @@ -122,6 +123,28 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str): final_state = None prev_statuses = {} + # Emit all analysts as "in_progress" immediately (they run in parallel) + analyst_name_map = { + "market": "Market Analyst", + "social": "Social Analyst", + "news": "News Analyst", + "fundamentals": "Fundamentals Analyst", + } + for analyst_type in selected_analysts: + agent_name = analyst_name_map[analyst_type] + buf.update_agent_status(agent_name, "in_progress") + st = get_stats_dict(stats_handler, buf, start_time) + evt = { + "type": "agent_update", + "agent": agent_name, + "stage": "analysts", + "status": "in_progress", + "stats": st, + } + state["events"].append(evt) + await q.put(evt) + prev_statuses[agent_name] = "in_progress" + try: async for chunk in graph.graph.astream(init_state, **args): final_state = chunk diff --git a/tradingagents/graph/parallel_analysts.py b/tradingagents/graph/parallel_analysts.py new file mode 100644 index 00000000..1fddd8a6 --- /dev/null +++ b/tradingagents/graph/parallel_analysts.py @@ -0,0 +1,76 @@ +"""Parallel analyst execution for TradingAgents. + +Runs all analyst agents (Market, Social, News, Fundamentals) concurrently +instead of sequentially, cutting the analyst phase from ~8-9 min to ~2-3 min. +""" + +import asyncio +from langchain_core.messages import HumanMessage, RemoveMessage + + +def create_parallel_analyst_node(analyst_fns, tool_nodes, selected_analysts): + """Create a single LangGraph node that runs all analysts in parallel. + + Each analyst gets its own isolated message state and runs its complete + tool-calling loop independently. Results are merged at the end. + + Args: + analyst_fns: dict mapping analyst type (e.g. "market") to node function + tool_nodes: dict mapping analyst type to ToolNode instance + selected_analysts: list of analyst types to run + """ + + async def parallel_analysts_node(state): + """Run all analysts concurrently and merge their reports.""" + + async def run_single(analyst_type): + """Run one analyst through its complete tool-calling loop.""" + fn = analyst_fns[analyst_type] + tn = tool_nodes[analyst_type] + + # Each analyst gets its own isolated message state + local_state = { + "messages": list(state["messages"]), + "trade_date": state["trade_date"], + "company_of_interest": state["company_of_interest"], + } + + result = {} + for _ in range(10): # safety limit on tool rounds + result = await asyncio.to_thread(fn, local_state) + ai_msg = result["messages"][0] + local_state["messages"] = local_state["messages"] + [ai_msg] + + if not ai_msg.tool_calls: + break + + # Process tool calls + tool_result = await asyncio.to_thread(tn.invoke, local_state) + local_state["messages"] = ( + local_state["messages"] + tool_result["messages"] + ) + + # Return only report fields (not messages) + return {k: v for k, v in result.items() if k != "messages"} + + # Run all analysts concurrently + tasks = [run_single(at) for at in selected_analysts if at in analyst_fns] + results = await asyncio.gather(*tasks) + + # Merge all report fields + merged = {} + for r in results: + merged.update(r) + + # Clear messages and add placeholder (same as Msg Clear nodes) + messages = state.get("messages", []) + removal_ops = [ + RemoveMessage(id=m.id) + for m in messages + if hasattr(m, "id") and m.id + ] + merged["messages"] = removal_ops + [HumanMessage(content="Continue")] + + return merged + + return parallel_analysts_node diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 772efe7f..41c69140 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -9,6 +9,7 @@ from tradingagents.agents import * from tradingagents.agents.utils.agent_states import AgentState from .conditional_logic import ConditionalLogic +from .parallel_analysts import create_parallel_analyst_node class GraphSetup: @@ -38,7 +39,8 @@ class GraphSetup: self.conditional_logic = conditional_logic def setup_graph( - self, selected_analysts=["market", "social", "news", "fundamentals"] + self, selected_analysts=["market", "social", "news", "fundamentals"], + parallel=False, ): """Set up and compile the agent workflow graph. @@ -48,11 +50,12 @@ class GraphSetup: - "social": Social media analyst - "news": News analyst - "fundamentals": Fundamentals analyst + parallel (bool): Run analysts in parallel instead of sequentially. """ if len(selected_analysts) == 0: raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") - # Create analyst nodes + # Create analyst node functions and tool nodes analyst_nodes = {} delete_nodes = {} tool_nodes = {} @@ -108,13 +111,20 @@ class GraphSetup: # Create workflow workflow = StateGraph(AgentState) - # Add analyst nodes to the graph - for analyst_type, node in analyst_nodes.items(): - workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) - workflow.add_node( - f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type] + if parallel: + # Single node runs all analysts concurrently + parallel_node = create_parallel_analyst_node( + analyst_nodes, tool_nodes, selected_analysts ) - workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) + workflow.add_node("Parallel Analysts", parallel_node) + else: + # Add analyst nodes individually for sequential execution + for analyst_type, node in analyst_nodes.items(): + workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) + workflow.add_node( + f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type] + ) + workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) # Add other nodes workflow.add_node("Bull Researcher", bull_researcher_node) @@ -127,32 +137,34 @@ class GraphSetup: workflow.add_node("Risk Judge", risk_manager_node) # Define edges - # Start with the first analyst - first_analyst = selected_analysts[0] - workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst") + if parallel: + # Parallel: START → Parallel Analysts → Bull Researcher + workflow.add_edge(START, "Parallel Analysts") + workflow.add_edge("Parallel Analysts", "Bull Researcher") + else: + # Sequential: START → Analyst 1 → ... → Analyst N → Bull Researcher + first_analyst = selected_analysts[0] + workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst") - # Connect analysts in sequence - for i, analyst_type in enumerate(selected_analysts): - current_analyst = f"{analyst_type.capitalize()} Analyst" - current_tools = f"tools_{analyst_type}" - current_clear = f"Msg Clear {analyst_type.capitalize()}" + for i, analyst_type in enumerate(selected_analysts): + current_analyst = f"{analyst_type.capitalize()} Analyst" + current_tools = f"tools_{analyst_type}" + current_clear = f"Msg Clear {analyst_type.capitalize()}" - # Add conditional edges for current analyst - workflow.add_conditional_edges( - current_analyst, - getattr(self.conditional_logic, f"should_continue_{analyst_type}"), - [current_tools, current_clear], - ) - workflow.add_edge(current_tools, current_analyst) + workflow.add_conditional_edges( + current_analyst, + getattr(self.conditional_logic, f"should_continue_{analyst_type}"), + [current_tools, current_clear], + ) + workflow.add_edge(current_tools, current_analyst) - # Connect to next analyst or to Bull Researcher if this is the last analyst - if i < len(selected_analysts) - 1: - next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" - workflow.add_edge(current_clear, next_analyst) - else: - workflow.add_edge(current_clear, "Bull Researcher") + if i < len(selected_analysts) - 1: + next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" + workflow.add_edge(current_clear, next_analyst) + else: + workflow.add_edge(current_clear, "Bull Researcher") - # Add remaining edges + # Add remaining edges (same for both modes) workflow.add_conditional_edges( "Bull Researcher", self.conditional_logic.should_continue_debate, diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 44ecca0c..bd33c4b9 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -127,8 +127,9 @@ class TradingAgentsGraph: self.ticker = None self.log_states_dict = {} # date to full state dict - # Set up the graph - self.graph = self.graph_setup.setup_graph(selected_analysts) + # Set up the graph (parallel analysts for speed when enabled) + parallel = self.config.get("parallel_analysts", False) + self.graph = self.graph_setup.setup_graph(selected_analysts, parallel=parallel) def _get_provider_kwargs(self) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation."""