# TradingAgents/tests/api/test_stream_propagate.py
#
# 128 lines
# 4.9 KiB
# Python

import logging
from unittest.mock import MagicMock, patch, PropertyMock

import pytest

from tradingagents.graph.trading_graph import TradingAgentsGraph
def _make_graph(chunks):
    """Return a TradingAgentsGraph stand-in whose ``graph.stream`` yields *chunks*.

    The instance is allocated with ``__new__``, so ``__init__`` is not run;
    the attributes the tests rely on are then attached by hand, with mocks
    standing in for the graph, LLM, signal processor, propagator and
    ``_log_state``.

    Args:
        chunks: Iterable of update dicts for the mocked ``graph.stream`` to
            produce. It is wrapped in a single iterator, so the stream can
            only be drained once per helper call.

    Returns:
        The hand-assembled ``TradingAgentsGraph`` instance.
    """
    # Canned return value for propagator.create_initial_state().
    initial_state = {
        "messages": [], "company_of_interest": "NVDA", "trade_date": "2026-03-23",
        "investment_debate_state": {
            "bull_history": "", "bear_history": "", "history": "",
            "current_response": "", "judge_decision": "", "count": 0
        },
        "risk_debate_state": {
            "aggressive_history": "", "conservative_history": "", "neutral_history": "",
            "history": "", "latest_speaker": "", "current_aggressive_response": "",
            "current_conservative_response": "", "current_neutral_response": "",
            "judge_decision": "", "count": 0
        },
        "market_report": "", "fundamentals_report": "",
        "sentiment_report": "", "news_report": "",
    }
    # Canned return value for propagator.get_graph_args().
    graph_args = {
        "stream_mode": "updates",
        "config": {"configurable": {"thread_id": "test-thread"}, "recursion_limit": 100},
    }

    # __init__ is swapped for a no-op while the bare instance is allocated.
    with patch.object(TradingAgentsGraph, '__init__', lambda self, *a, **kw: None):
        stub = TradingAgentsGraph.__new__(TradingAgentsGraph)

    compiled_graph = MagicMock()
    compiled_graph.stream.return_value = iter(chunks)
    compiled_graph.get_state.return_value = MagicMock(next=None)  # no checkpoint to resume
    stub.graph = compiled_graph

    stub.quick_thinking_llm = MagicMock()

    signal_processor = MagicMock()
    signal_processor.process_signal.return_value = "BUY"
    stub.signal_processor = signal_processor

    stub.config = {"llm_provider": "openai", "deep_think_llm": "gpt-4",
                   "quick_think_llm": "gpt-4-mini", "max_debate_rounds": 1,
                   "max_risk_discuss_rounds": 1, "results_dir": "./results"}
    stub._last_decision = None
    stub.selected_analysts = ["market", "news", "fundamentals", "social"]

    # propagator needed by stream_propagate for initial state and graph args
    propagator = MagicMock()
    propagator.create_initial_state.return_value = initial_state
    propagator.get_graph_args.return_value = graph_args
    stub.propagator = propagator

    # _log_state writes to disk — mock it out in all tests
    stub._log_state = MagicMock()
    return stub
def test_yields_known_node():
    """A "Market Analyst" chunk is yielded as ``("market_analyst", <report>)``."""
    stream_chunks = [{"Market Analyst": {"market_report": "bullish outlook"}}]
    agent_graph = _make_graph(stream_chunks)

    emitted = list(agent_graph.stream_propagate("NVDA", "2026-03-23"))

    expected = [("market_analyst", "bullish outlook")]
    assert emitted == expected
def test_skips_tool_nodes():
    """A ``tools_market`` chunk yields nothing; only the analyst chunk comes through."""
    tool_chunk = {"tools_market": {"messages": []}}
    analyst_chunk = {"Market Analyst": {"market_report": "ok"}}
    agent_graph = _make_graph([tool_chunk, analyst_chunk])

    emitted = list(agent_graph.stream_propagate("NVDA", "2026-03-23"))

    # Exactly one item, keyed by the analyst node.
    emitted_keys = [item[0] for item in emitted]
    assert emitted_keys == ["market_analyst"]
def test_skips_msg_clear_nodes():
    """A "Msg Clear Market" chunk yields nothing; the news chunk that follows does."""
    clear_chunk = {"Msg Clear Market": {}}
    news_chunk = {"News Analyst": {"news_report": "stable"}}
    agent_graph = _make_graph([clear_chunk, news_chunk])

    emitted = list(agent_graph.stream_propagate("NVDA", "2026-03-23"))

    # Exactly one item, keyed by the news analyst node.
    emitted_keys = [item[0] for item in emitted]
    assert emitted_keys == ["news_analyst"]
def test_skips_unknown_nodes_with_warning(caplog):
    """An unrecognised node is left out of the stream and its name is logged.

    Args:
        caplog: pytest's built-in log-capture fixture.
    """
    ta = _make_graph([
        {"Unknown Future Node": {"some_field": "value"}},
        {"Trader": {"trader_investment_plan": "buy 100 shares"}},
    ])

    # Capture at WARNING and above while the generator is drained.
    with caplog.at_level(logging.WARNING):
        results = list(ta.stream_propagate("NVDA", "2026-03-23"))

    assert len(results) == 1
    assert results[0][0] == "trader"
    # caplog.messages is the interpolated text of each captured record; it
    # does not depend on the ``message`` attribute that a LogRecord only
    # gains once a formatter has processed it.
    assert any("Unknown Future Node" in text for text in caplog.messages)
def test_last_decision_set_after_exhaustion():
    """Draining the stream leaves the processed signal in ``_last_decision``."""
    risk_chunk = {
        "Risk Judge": {"risk_debate_state": {"judge_decision": "SELL signal strong"}},
    }
    agent_graph = _make_graph([risk_chunk])

    # Replace the helper's default snapshot with one carrying a final decision.
    # NOTE(review): presumably stream_propagate reads get_state() once the
    # stream has ended to obtain the full final state — confirm against the
    # implementation.
    final_snapshot = MagicMock(
        next=None,
        values={"final_trade_decision": "strong SELL signal from risk team"},
    )
    agent_graph.graph.get_state.return_value = final_snapshot

    # Only exhaustion matters here; the yielded items are discarded.
    for _ in agent_graph.stream_propagate("NVDA", "2026-03-23"):
        pass

    # The mocked signal_processor always answers "BUY".
    assert agent_graph._last_decision == "BUY"
def test_bull_researcher_extracts_bull_history():
    """The "Bull Researcher" item carries ``investment_debate_state["bull_history"]``."""
    debate_state = {
        "bull_history": "bullish case round 1", "bear_history": "",
        "history": "", "current_response": "", "judge_decision": "", "count": 1
    }
    agent_graph = _make_graph([
        {"Bull Researcher": {"investment_debate_state": debate_state}},
    ])

    emitted = list(agent_graph.stream_propagate("NVDA", "2026-03-23"))

    first_item = emitted[0]
    assert first_item == ("bull_researcher", "bullish case round 1")
def test_research_manager_extracts_investment_plan():
    """The "Research Manager" item carries the chunk's ``investment_plan``."""
    manager_chunk = {"Research Manager": {"investment_plan": "Invest 20% in NVDA"}}
    agent_graph = _make_graph([manager_chunk])

    emitted = list(agent_graph.stream_propagate("NVDA", "2026-03-23"))

    first_item = emitted[0]
    assert first_item == ("research_manager", "Invest 20% in NVDA")
def test_missing_field_yields_empty_string():
    """A known node whose report key is absent is yielded with ``""`` as content."""
    empty_update = {"Market Analyst": {}}  # no market_report key
    agent_graph = _make_graph([empty_update])

    emitted = list(agent_graph.stream_propagate("NVDA", "2026-03-23"))

    first_item = emitted[0]
    assert first_item == ("market_analyst", "")