85 lines
2.9 KiB
Python
85 lines
2.9 KiB
Python
import unittest
|
|
|
|
from tradingagents.graph.analyst_execution import (
|
|
AnalystWallTimeTracker,
|
|
build_analyst_execution_plan,
|
|
get_initial_analyst_node,
|
|
sync_analyst_tracker_from_chunk,
|
|
)
|
|
|
|
|
|
class AnalystExecutionPlanTests(unittest.TestCase):
    """Exercise ``build_analyst_execution_plan`` and ``get_initial_analyst_node``."""

    def test_build_plan_preserves_selected_order(self):
        # The requested keys are passed in a specific order; the resulting
        # specs must come back in that same order.
        plan = build_analyst_execution_plan(
            ["news", "market"],
            concurrency_limit=2,
        )

        spec_keys = [spec.key for spec in plan.specs]
        self.assertEqual(spec_keys, ["news", "market"])
        self.assertEqual(plan.concurrency_limit, 2)

        # The first spec carries the node names associated with "news".
        leading_spec = plan.specs[0]
        self.assertEqual(leading_spec.agent_node, "News Analyst")
        self.assertEqual(leading_spec.tool_node, "tools_news")
        self.assertEqual(leading_spec.clear_node, "Msg Clear News")

    def test_rejects_unknown_analyst_keys(self):
        # "macro" is expected to be rejected with a ValueError.
        self.assertRaises(
            ValueError,
            build_analyst_execution_plan,
            ["market", "macro"],
        )

    def test_requires_positive_concurrency_limit(self):
        # A concurrency limit of zero is expected to raise ValueError.
        self.assertRaises(
            ValueError,
            build_analyst_execution_plan,
            ["market"],
            concurrency_limit=0,
        )

    def test_get_initial_analyst_node_uses_plan_metadata(self):
        # The initial node is derived from the plan built with
        # "fundamentals" as its first key.
        plan = build_analyst_execution_plan(["fundamentals", "news"])

        initial_node = get_initial_analyst_node(plan)

        self.assertEqual(initial_node, "Fundamentals Analyst")
|
|
|
|
|
|
class AnalystWallTimeTrackerTests(unittest.TestCase):
    """Exercise ``AnalystWallTimeTracker`` and ``sync_analyst_tracker_from_chunk``."""

    def test_records_wall_time_when_analyst_completes(self):
        tracker = AnalystWallTimeTracker(
            build_analyst_execution_plan(["market", "news"])
        )

        # Only "market" is started and completed, so only it is reported.
        tracker.mark_started("market", started_at=10.0)
        tracker.mark_completed("market", completed_at=13.5)

        recorded = tracker.get_wall_times()
        self.assertEqual(recorded, {"market": 3.5})

    def test_formats_summary_in_plan_order(self):
        tracker = AnalystWallTimeTracker(
            build_analyst_execution_plan(["news", "market"])
        )

        # Timings are recorded in the opposite order to the plan
        # ("market" first) so the summary ordering is actually exercised.
        timings = (
            ("market", 20.0, 22.25),
            ("news", 10.0, 14.0),
        )
        for analyst_key, began, ended in timings:
            tracker.mark_started(analyst_key, started_at=began)
            tracker.mark_completed(analyst_key, completed_at=ended)

        summary = tracker.format_summary()
        self.assertEqual(
            summary,
            "Analyst wall time: News 4.00s | Market 2.25s",
        )

    def test_syncs_wall_time_from_sequential_chunks(self):
        tracker = AnalystWallTimeTracker(
            build_analyst_execution_plan(["market", "news"])
        )

        # Each step feeds one chunk at a given timestamp and states the
        # wall times expected immediately afterwards.
        steps = (
            ({}, 10.0, {}),
            ({"market_report": "done"}, 13.0, {"market": 3.0}),
            (
                {"market_report": "done", "news_report": "done"},
                18.0,
                {"market": 3.0, "news": 5.0},
            ),
        )
        for chunk, timestamp, expected_wall_times in steps:
            sync_analyst_tracker_from_chunk(tracker, chunk, now=timestamp)
            self.assertEqual(tracker.get_wall_times(), expected_wall_times)
|