From 2d2c9e6d66bfe6a0393d41a93e3c9189baf9f7f7 Mon Sep 17 00:00:00 2001 From: CadeYu Date: Tue, 31 Mar 2026 09:55:33 +0800 Subject: [PATCH 1/2] add analyst execution planning and timing hooks --- cli/main.py | 26 ++++- tests/test_analyst_execution.py | 75 +++++++++++++ tradingagents/default_config.py | 1 + tradingagents/graph/analyst_execution.py | 132 +++++++++++++++++++++++ tradingagents/graph/setup.py | 75 +++++-------- tradingagents/graph/trading_graph.py | 1 + 6 files changed, 257 insertions(+), 53 deletions(-) create mode 100644 tests/test_analyst_execution.py create mode 100644 tradingagents/graph/analyst_execution.py diff --git a/cli/main.py b/cli/main.py index adda48fc..b4990c70 100644 --- a/cli/main.py +++ b/cli/main.py @@ -24,6 +24,11 @@ from rich.align import Align from rich.rule import Rule from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.graph.analyst_execution import ( + AnalystWallTimeTracker, + build_analyst_execution_plan, + sync_analyst_tracker_from_chunk, +) from tradingagents.default_config import DEFAULT_CONFIG from cli.models import AnalystType from cli.utils import * @@ -787,7 +792,7 @@ ANALYST_REPORT_MAP = { } -def update_analyst_statuses(message_buffer, chunk): +def update_analyst_statuses(message_buffer, chunk, wall_time_tracker=None): """Update all analyst statuses based on current report state. Logic: @@ -799,6 +804,9 @@ def update_analyst_statuses(message_buffer, chunk): selected = message_buffer.selected_analysts found_active = False + if wall_time_tracker is not None: + sync_analyst_tracker_from_chunk(wall_time_tracker, chunk) + for analyst_key in ANALYST_ORDER: if analyst_key not in selected: continue @@ -918,6 +926,11 @@ def run_analysis(): # Normalize analyst selection to predefined order (selection is a 'set', order is fixed) selected_set = {analyst.value for analyst in selections["analysts"]} selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set] + analyst_execution_plan = build_analyst_execution_plan( + selected_analyst_keys, + concurrency_limit=config["analyst_concurrency_limit"], + ) + analyst_wall_time_tracker = AnalystWallTimeTracker(analyst_execution_plan) # Initialize the graph with callbacks bound to LLMs graph = TradingAgentsGraph( @@ -999,8 +1012,9 @@ def run_analysis(): update_display(layout, stats_handler=stats_handler, start_time=start_time) # Update agent status to in_progress for the first analyst - first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst" + first_analyst = f"{selected_analyst_keys[0].capitalize()} Analyst" message_buffer.update_agent_status(first_analyst, "in_progress") + analyst_wall_time_tracker.mark_started(selected_analyst_keys[0]) update_display(layout, stats_handler=stats_handler, start_time=start_time) # Create spinner text @@ -1044,7 +1058,11 @@ def run_analysis(): message_buffer.add_tool_call(tool_call.name, tool_call.args) # Update analyst statuses based on report state (runs on every chunk) - update_analyst_statuses(message_buffer, chunk) + update_analyst_statuses( + message_buffer, + chunk, + wall_time_tracker=analyst_wall_time_tracker, + ) # Research Team - Handle Investment Debate State if chunk.get("investment_debate_state"): @@ -1133,6 +1151,7 @@ def run_analysis(): message_buffer.add_message( "System", f"Completed analysis for {selections['analysis_date']}" ) + message_buffer.add_message("System", analyst_wall_time_tracker.format_summary()) # Update final report sections for section in message_buffer.report_sections.keys(): @@ -1143,6 +1162,7 @@ def run_analysis(): # Post-analysis prompts (outside Live context for clean interaction) console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n") + console.print(f"[dim]{analyst_wall_time_tracker.format_summary()}[/dim]") # Prompt to save report save_choice = typer.prompt("Save report?", default="Y").strip().upper() diff --git a/tests/test_analyst_execution.py b/tests/test_analyst_execution.py new file mode 100644 index 00000000..018af73e --- /dev/null +++ b/tests/test_analyst_execution.py @@ -0,0 +1,75 @@ +import unittest + +from tradingagents.graph.analyst_execution import ( + AnalystWallTimeTracker, + build_analyst_execution_plan, + sync_analyst_tracker_from_chunk, +) + + +class AnalystExecutionPlanTests(unittest.TestCase): + def test_build_plan_preserves_selected_order(self): + plan = build_analyst_execution_plan(["news", "market"], concurrency_limit=2) + + self.assertEqual([spec.key for spec in plan.specs], ["news", "market"]) + self.assertEqual(plan.concurrency_limit, 2) + self.assertEqual(plan.specs[0].agent_node, "News Analyst") + self.assertEqual(plan.specs[0].tool_node, "tools_news") + self.assertEqual(plan.specs[0].clear_node, "Msg Clear News") + + def test_rejects_unknown_analyst_keys(self): + with self.assertRaises(ValueError): + build_analyst_execution_plan(["market", "macro"]) + + def test_requires_positive_concurrency_limit(self): + with self.assertRaises(ValueError): + build_analyst_execution_plan(["market"], concurrency_limit=0) + + +class AnalystWallTimeTrackerTests(unittest.TestCase): + def test_records_wall_time_when_analyst_completes(self): + plan = build_analyst_execution_plan(["market", "news"]) + tracker = AnalystWallTimeTracker(plan) + + tracker.mark_started("market", started_at=10.0) + tracker.mark_completed("market", completed_at=13.5) + + self.assertEqual(tracker.get_wall_times(), {"market": 3.5}) + + def test_formats_summary_in_plan_order(self): + plan = build_analyst_execution_plan(["news", "market"]) + tracker = AnalystWallTimeTracker(plan) + + tracker.mark_started("market", started_at=20.0) + tracker.mark_completed("market", completed_at=22.25) + tracker.mark_started("news", started_at=10.0) + tracker.mark_completed("news", completed_at=14.0) + + self.assertEqual( + tracker.format_summary(), + "Analyst wall time: News 4.00s | Market 2.25s", + ) + + def test_syncs_wall_time_from_sequential_chunks(self): + plan = build_analyst_execution_plan(["market", "news"]) + tracker = AnalystWallTimeTracker(plan) + + sync_analyst_tracker_from_chunk(tracker, {}, now=10.0) + self.assertEqual(tracker.get_wall_times(), {}) + + sync_analyst_tracker_from_chunk( + tracker, + {"market_report": "done"}, + now=13.0, + ) + self.assertEqual(tracker.get_wall_times(), {"market": 3.0}) + + sync_analyst_tracker_from_chunk( + tracker, + {"market_report": "done", "news_report": "done"}, + now=18.0, + ) + self.assertEqual( + tracker.get_wall_times(), + {"market": 3.0, "news": 5.0}, + ) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index ecf0dc29..4d8c1025 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -19,6 +19,7 @@ DEFAULT_CONFIG = { "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, "max_recur_limit": 100, + "analyst_concurrency_limit": 1, # Data vendor configuration # Category-level configuration (default for all tools in category) "data_vendors": { diff --git a/tradingagents/graph/analyst_execution.py b/tradingagents/graph/analyst_execution.py new file mode 100644 index 00000000..12dd8e19 --- /dev/null +++ b/tradingagents/graph/analyst_execution.py @@ -0,0 +1,132 @@ +from dataclasses import dataclass +from time import monotonic +from typing import Dict, Iterable, List, Optional + + +@dataclass(frozen=True) +class AnalystNodeSpec: + key: str + agent_node: str + clear_node: str + tool_node: str + report_key: str + + +@dataclass(frozen=True) +class AnalystExecutionPlan: + specs: List[AnalystNodeSpec] + concurrency_limit: int + + +ANALYST_NODE_SPECS: Dict[str, AnalystNodeSpec] = { + "market": AnalystNodeSpec( + key="market", + agent_node="Market Analyst", + clear_node="Msg Clear Market", + tool_node="tools_market", + report_key="market_report", + ), + "social": AnalystNodeSpec( + key="social", + agent_node="Social Analyst", + clear_node="Msg Clear Social", + tool_node="tools_social", + report_key="sentiment_report", + ), + "news": AnalystNodeSpec( + key="news", + agent_node="News Analyst", + clear_node="Msg Clear News", + tool_node="tools_news", + report_key="news_report", + ), + "fundamentals": AnalystNodeSpec( + key="fundamentals", + agent_node="Fundamentals Analyst", + clear_node="Msg Clear Fundamentals", + tool_node="tools_fundamentals", + report_key="fundamentals_report", + ), +} + + +def build_analyst_execution_plan( + selected_analysts: Iterable[str], + concurrency_limit: int = 1, +) -> AnalystExecutionPlan: + if concurrency_limit < 1: + raise ValueError("analyst concurrency limit must be >= 1") + + specs: List[AnalystNodeSpec] = [] + for analyst_key in selected_analysts: + spec = ANALYST_NODE_SPECS.get(analyst_key) + if spec is None: + raise ValueError(f"unknown analyst key: {analyst_key}") + specs.append(spec) + + if not specs: + raise ValueError("at least one analyst must be selected") + + return AnalystExecutionPlan(specs=specs, concurrency_limit=concurrency_limit) + + +class AnalystWallTimeTracker: + def __init__(self, plan: AnalystExecutionPlan): + self.plan = plan + self._started_at: Dict[str, float] = {} + self._wall_times: Dict[str, float] = {} + + def mark_started(self, analyst_key: str, started_at: Optional[float] = None) -> None: + if analyst_key not in ANALYST_NODE_SPECS: + raise ValueError(f"unknown analyst key: {analyst_key}") + self._started_at.setdefault(analyst_key, monotonic() if started_at is None else started_at) + + def mark_completed( + self, + analyst_key: str, + completed_at: Optional[float] = None, + ) -> None: + if analyst_key not in ANALYST_NODE_SPECS: + raise ValueError(f"unknown analyst key: {analyst_key}") + if analyst_key in self._wall_times: + return + started_at = self._started_at.get(analyst_key) + if started_at is None: + return + finished_at = monotonic() if completed_at is None else completed_at + self._wall_times[analyst_key] = max(0.0, finished_at - started_at) + + def get_wall_times(self) -> Dict[str, float]: + return dict(self._wall_times) + + def format_summary(self) -> str: + parts = [] + for spec in self.plan.specs: + duration = self._wall_times.get(spec.key) + if duration is not None: + label = spec.agent_node.removesuffix(" Analyst") + parts.append(f"{label} {duration:.2f}s") + if not parts: + return "Analyst wall time: pending" + return "Analyst wall time: " + " | ".join(parts) + + +def sync_analyst_tracker_from_chunk( + tracker: AnalystWallTimeTracker, + chunk: Dict[str, str], + now: Optional[float] = None, +) -> None: + current_time = monotonic() if now is None else now + active_found = False + + for spec in tracker.plan.specs: + has_report = bool(chunk.get(spec.report_key)) + + if has_report: + tracker.mark_started(spec.key, started_at=current_time) + tracker.mark_completed(spec.key, completed_at=current_time) + continue + + if not active_found: + tracker.mark_started(spec.key, started_at=current_time) + active_found = True diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 772efe7f..11b7e56a 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -8,6 +8,7 @@ from langgraph.prebuilt import ToolNode from tradingagents.agents import * from tradingagents.agents.utils.agent_states import AgentState +from .analyst_execution import build_analyst_execution_plan from .conditional_logic import ConditionalLogic @@ -25,6 +26,7 @@ class GraphSetup: invest_judge_memory, risk_manager_memory, conditional_logic: ConditionalLogic, + analyst_concurrency_limit: int = 1, ): """Initialize with required components.""" self.quick_thinking_llm = quick_thinking_llm @@ -36,6 +38,7 @@ class GraphSetup: self.invest_judge_memory = invest_judge_memory self.risk_manager_memory = risk_manager_memory self.conditional_logic = conditional_logic + self.analyst_concurrency_limit = analyst_concurrency_limit def setup_graph( self, selected_analysts=["market", "social", "news", "fundamentals"] @@ -49,41 +52,17 @@ class GraphSetup: - "news": News analyst - "fundamentals": Fundamentals analyst """ - if len(selected_analysts) == 0: - raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") + plan = build_analyst_execution_plan( + selected_analysts, + concurrency_limit=self.analyst_concurrency_limit, + ) - # Create analyst nodes - analyst_nodes = {} - delete_nodes = {} - tool_nodes = {} - - if "market" in selected_analysts: - analyst_nodes["market"] = create_market_analyst( - self.quick_thinking_llm - ) - delete_nodes["market"] = create_msg_delete() - tool_nodes["market"] = self.tool_nodes["market"] - - if "social" in selected_analysts: - analyst_nodes["social"] = create_social_media_analyst( - self.quick_thinking_llm - ) - delete_nodes["social"] = create_msg_delete() - tool_nodes["social"] = self.tool_nodes["social"] - - if "news" in selected_analysts: - analyst_nodes["news"] = create_news_analyst( - self.quick_thinking_llm - ) - delete_nodes["news"] = create_msg_delete() - tool_nodes["news"] = self.tool_nodes["news"] - - if "fundamentals" in selected_analysts: - analyst_nodes["fundamentals"] = create_fundamentals_analyst( - self.quick_thinking_llm - ) - delete_nodes["fundamentals"] = create_msg_delete() - tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] + analyst_factories = { + "market": lambda: create_market_analyst(self.quick_thinking_llm), + "social": lambda: create_social_media_analyst(self.quick_thinking_llm), + "news": lambda: create_news_analyst(self.quick_thinking_llm), + "fundamentals": lambda: create_fundamentals_analyst(self.quick_thinking_llm), + } # Create researcher and manager nodes bull_researcher_node = create_bull_researcher( @@ -109,12 +88,10 @@ class GraphSetup: workflow = StateGraph(AgentState) # Add analyst nodes to the graph - for analyst_type, node in analyst_nodes.items(): - workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) - workflow.add_node( - f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type] - ) - workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) + for spec in plan.specs: + workflow.add_node(spec.agent_node, analyst_factories[spec.key]()) + workflow.add_node(spec.clear_node, create_msg_delete()) + workflow.add_node(spec.tool_node, self.tool_nodes[spec.key]) # Add other nodes workflow.add_node("Bull Researcher", bull_researcher_node) @@ -128,27 +105,25 @@ class GraphSetup: # Define edges # Start with the first analyst - first_analyst = selected_analysts[0] - workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst") + workflow.add_edge(START, plan.specs[0].agent_node) # Connect analysts in sequence - for i, analyst_type in enumerate(selected_analysts): - current_analyst = f"{analyst_type.capitalize()} Analyst" - current_tools = f"tools_{analyst_type}" - current_clear = f"Msg Clear {analyst_type.capitalize()}" + for i, spec in enumerate(plan.specs): + current_analyst = spec.agent_node + current_tools = spec.tool_node + current_clear = spec.clear_node # Add conditional edges for current analyst workflow.add_conditional_edges( current_analyst, - getattr(self.conditional_logic, f"should_continue_{analyst_type}"), + getattr(self.conditional_logic, f"should_continue_{spec.key}"), [current_tools, current_clear], ) workflow.add_edge(current_tools, current_analyst) # Connect to next analyst or to Bull Researcher if this is the last analyst - if i < len(selected_analysts) - 1: - next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" - workflow.add_edge(current_clear, next_analyst) + if i < len(plan.specs) - 1: + workflow.add_edge(current_clear, plan.specs[i + 1].agent_node) else: workflow.add_edge(current_clear, "Bull Researcher") diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c7ef0f98..0f1704cd 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -119,6 +119,7 @@ class TradingAgentsGraph: self.invest_judge_memory, self.risk_manager_memory, self.conditional_logic, + analyst_concurrency_limit=self.config.get("analyst_concurrency_limit", 1), ) self.propagator = Propagator() From f4519bcb84ab6a9ebf687ca83a3bf8e15d22ded4 Mon Sep 17 00:00:00 2001 From: CadeYu Date: Tue, 31 Mar 2026 10:09:57 +0800 Subject: [PATCH 2/2] use execution plan metadata for first analyst --- cli/main.py | 3 ++- tests/test_analyst_execution.py | 9 +++++++++ tradingagents/graph/analyst_execution.py | 4 ++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/cli/main.py b/cli/main.py index aeed6d02..7ac3af63 100644 --- a/cli/main.py +++ b/cli/main.py @@ -27,6 +27,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.analyst_execution import ( AnalystWallTimeTracker, build_analyst_execution_plan, + get_initial_analyst_node, sync_analyst_tracker_from_chunk, ) from tradingagents.default_config import DEFAULT_CONFIG @@ -1044,7 +1045,7 @@ def run_analysis(): update_display(layout, stats_handler=stats_handler, start_time=start_time) # Update agent status to in_progress for the first analyst - first_analyst = f"{selected_analyst_keys[0].capitalize()} Analyst" + first_analyst = get_initial_analyst_node(analyst_execution_plan) message_buffer.update_agent_status(first_analyst, "in_progress") analyst_wall_time_tracker.mark_started(selected_analyst_keys[0]) update_display(layout, stats_handler=stats_handler, start_time=start_time) diff --git a/tests/test_analyst_execution.py b/tests/test_analyst_execution.py index 018af73e..7d4e851a 100644 --- a/tests/test_analyst_execution.py +++ b/tests/test_analyst_execution.py @@ -3,6 +3,7 @@ import unittest from tradingagents.graph.analyst_execution import ( AnalystWallTimeTracker, build_analyst_execution_plan, + get_initial_analyst_node, sync_analyst_tracker_from_chunk, ) @@ -25,6 +26,14 @@ class AnalystExecutionPlanTests(unittest.TestCase): with self.assertRaises(ValueError): build_analyst_execution_plan(["market"], concurrency_limit=0) + def test_get_initial_analyst_node_uses_plan_metadata(self): + plan = build_analyst_execution_plan(["fundamentals", "news"]) + + self.assertEqual( + get_initial_analyst_node(plan), + "Fundamentals Analyst", + ) + class AnalystWallTimeTrackerTests(unittest.TestCase): def test_records_wall_time_when_analyst_completes(self): diff --git a/tradingagents/graph/analyst_execution.py b/tradingagents/graph/analyst_execution.py index 12dd8e19..080668f1 100644 --- a/tradingagents/graph/analyst_execution.py +++ b/tradingagents/graph/analyst_execution.py @@ -70,6 +70,10 @@ def build_analyst_execution_plan( return AnalystExecutionPlan(specs=specs, concurrency_limit=concurrency_limit) +def get_initial_analyst_node(plan: AnalystExecutionPlan) -> str: + return plan.specs[0].agent_node + + class AnalystWallTimeTracker: def __init__(self, plan: AnalystExecutionPlan): self.plan = plan