diff --git a/cli/main.py b/cli/main.py index aeed6d02..7ac3af63 100644 --- a/cli/main.py +++ b/cli/main.py @@ -27,6 +27,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.analyst_execution import ( AnalystWallTimeTracker, build_analyst_execution_plan, + get_initial_analyst_node, sync_analyst_tracker_from_chunk, ) from tradingagents.default_config import DEFAULT_CONFIG @@ -1044,7 +1045,7 @@ def run_analysis(): update_display(layout, stats_handler=stats_handler, start_time=start_time) # Update agent status to in_progress for the first analyst - first_analyst = f"{selected_analyst_keys[0].capitalize()} Analyst" + first_analyst = get_initial_analyst_node(analyst_execution_plan) message_buffer.update_agent_status(first_analyst, "in_progress") analyst_wall_time_tracker.mark_started(selected_analyst_keys[0]) update_display(layout, stats_handler=stats_handler, start_time=start_time) diff --git a/tests/test_analyst_execution.py b/tests/test_analyst_execution.py index 018af73e..7d4e851a 100644 --- a/tests/test_analyst_execution.py +++ b/tests/test_analyst_execution.py @@ -3,6 +3,7 @@ import unittest from tradingagents.graph.analyst_execution import ( AnalystWallTimeTracker, build_analyst_execution_plan, + get_initial_analyst_node, sync_analyst_tracker_from_chunk, ) @@ -25,6 +26,14 @@ class AnalystExecutionPlanTests(unittest.TestCase): with self.assertRaises(ValueError): build_analyst_execution_plan(["market"], concurrency_limit=0) + def test_get_initial_analyst_node_uses_plan_metadata(self): + plan = build_analyst_execution_plan(["fundamentals", "news"]) + + self.assertEqual( + get_initial_analyst_node(plan), + "Fundamentals Analyst", + ) + class AnalystWallTimeTrackerTests(unittest.TestCase): def test_records_wall_time_when_analyst_completes(self): diff --git a/tradingagents/graph/analyst_execution.py b/tradingagents/graph/analyst_execution.py index 12dd8e19..080668f1 100644 --- a/tradingagents/graph/analyst_execution.py +++ b/tradingagents/graph/analyst_execution.py @@ -70,6 +70,10 @@ def build_analyst_execution_plan( return AnalystExecutionPlan(specs=specs, concurrency_limit=concurrency_limit) +def get_initial_analyst_node(plan: AnalystExecutionPlan) -> str: + return plan.specs[0].agent_node + + class AnalystWallTimeTracker: def __init__(self, plan: AnalystExecutionPlan): self.plan = plan