TradingAgents/tests/api/test_run_service.py

119 lines
4.3 KiB
Python

import pytest
from collections import defaultdict
from unittest.mock import patch, MagicMock
from api.services.run_service import RunService
from api.store.runs_store import RunsStore
from api.models.run import RunConfig, RunStatus
@pytest.fixture
def store():
return RunsStore()
@pytest.fixture
def service(store):
return RunService(store)
def _mock_graph(stream_yields, decision="BUY"):
"""Return a mock TradingAgentsGraph instance for stream_propagate tests."""
mock = MagicMock()
mock.stream_propagate.return_value = iter(stream_yields)
mock._last_decision = decision
return mock
def test_emits_agent_start_and_complete_per_step(service, store):
config = RunConfig(ticker="NVDA", date="2026-03-23")
run = store.create(config)
yields = [
("market_analyst", "bullish"),
("news_analyst", "stable"),
]
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
MockGraph.return_value = _mock_graph(yields)
events = list(service.stream_events(run.id))
starts = [e for e in events if e["event"] == "agent:start"]
completes = [e for e in events if e["event"] == "agent:complete"]
assert len(starts) == 2
assert len(completes) == 2
assert starts[0]["data"]["step"] == "market_analyst"
assert starts[0]["data"]["turn"] == 0
assert completes[0]["data"]["report"] == "bullish"
def test_turn_increments_for_repeat_steps(service, store):
config = RunConfig(ticker="NVDA", date="2026-03-23")
run = store.create(config)
yields = [
("bull_researcher", "bull round 1"),
("bear_researcher", "bear round 1"),
("bull_researcher", "bull round 2"),
]
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
MockGraph.return_value = _mock_graph(yields)
events = list(service.stream_events(run.id))
bull_completes = [e for e in events
if e["event"] == "agent:complete" and e["data"]["step"] == "bull_researcher"]
assert len(bull_completes) == 2
assert bull_completes[0]["data"]["turn"] == 0
assert bull_completes[1]["data"]["turn"] == 1
def test_run_complete_emitted_with_decision(service, store):
config = RunConfig(ticker="NVDA", date="2026-03-23")
run = store.create(config)
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
MockGraph.return_value = _mock_graph([("trader", "buy signal")], decision="SELL")
events = list(service.stream_events(run.id))
complete = next(e for e in events if e["event"] == "run:complete")
assert complete["data"]["decision"] == "SELL"
def test_store_status_set_to_complete(service, store):
config = RunConfig(ticker="NVDA", date="2026-03-23")
run = store.create(config)
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
MockGraph.return_value = _mock_graph([])
list(service.stream_events(run.id))
assert store.get(run.id).status == RunStatus.COMPLETE
def test_error_during_stream_emits_run_error(service, store):
config = RunConfig(ticker="NVDA", date="2026-03-23")
run = store.create(config)
def bad_stream(*args, **kwargs):
yield ("market_analyst", "ok")
raise RuntimeError("LLM network error")
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
mock = MagicMock()
mock.stream_propagate.side_effect = bad_stream
MockGraph.return_value = mock
events = list(service.stream_events(run.id))
error_events = [e for e in events if e["event"] == "run:error"]
assert len(error_events) == 1
assert "LLM network error" in error_events[0]["data"]["message"]
# Also verify the store recorded the error
assert store.get(run.id).error is not None
assert "LLM network error" in store.get(run.id).error
def test_selected_analysts_passed_to_graph(service, store):
config = RunConfig(ticker="NVDA", date="2026-03-23",
enabled_analysts=["market", "news"])
run = store.create(config)
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
MockGraph.return_value = _mock_graph([])
list(service.stream_events(run.id))
call_kwargs = MockGraph.call_args.kwargs
assert call_kwargs.get("selected_analysts") == ["market", "news"]