128 lines
4.9 KiB
Python
128 lines
4.9 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock, patch, PropertyMock
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
|
|
def _make_graph(chunks):
    """Create a TradingAgentsGraph test double whose ``graph.stream`` yields *chunks*.

    The instance is allocated with ``__new__`` so the real constructor body is
    never executed; every collaborator that ``stream_propagate`` touches is then
    attached by hand, either as a MagicMock or as plain test data.

    Args:
        chunks: Items that ``ta.graph.stream(...)`` should yield, in order.

    Returns:
        The hand-assembled TradingAgentsGraph instance.
    """
    # NOTE(review): object.__new__ does not call __init__, so this patch looks
    # redundant unless TradingAgentsGraph overrides __new__ — confirm before removing.
    with patch.object(TradingAgentsGraph, '__init__', lambda self, *a, **kw: None):
        ta = TradingAgentsGraph.__new__(TradingAgentsGraph)

    # Graph stand-in. iter() is taken once here, so every stream() call returns
    # the same iterator and only the first call actually sees the chunks.
    graph = MagicMock()
    graph.stream.return_value = iter(chunks)
    graph.get_state.return_value = MagicMock(next=None)  # no checkpoint to resume
    ta.graph = graph

    ta.quick_thinking_llm = MagicMock()

    processor = MagicMock()
    processor.process_signal.return_value = "BUY"
    ta.signal_processor = processor

    ta.config = {
        "llm_provider": "openai",
        "deep_think_llm": "gpt-4",
        "quick_think_llm": "gpt-4-mini",
        "max_debate_rounds": 1,
        "max_risk_discuss_rounds": 1,
        "results_dir": "./results",
    }
    ta._last_decision = None
    ta.selected_analysts = ["market", "news", "fundamentals", "social"]

    # propagator needed by stream_propagate for initial state and graph args
    investment_debate_state = {
        "bull_history": "",
        "bear_history": "",
        "history": "",
        "current_response": "",
        "judge_decision": "",
        "count": 0,
    }
    risk_debate_state = {
        "aggressive_history": "",
        "conservative_history": "",
        "neutral_history": "",
        "history": "",
        "latest_speaker": "",
        "current_aggressive_response": "",
        "current_conservative_response": "",
        "current_neutral_response": "",
        "judge_decision": "",
        "count": 0,
    }
    propagator = MagicMock()
    propagator.create_initial_state.return_value = {
        "messages": [],
        "company_of_interest": "NVDA",
        "trade_date": "2026-03-23",
        "investment_debate_state": investment_debate_state,
        "risk_debate_state": risk_debate_state,
        "market_report": "",
        "fundamentals_report": "",
        "sentiment_report": "",
        "news_report": "",
    }
    propagator.get_graph_args.return_value = {
        "stream_mode": "updates",
        "config": {"configurable": {"thread_id": "test-thread"}, "recursion_limit": 100},
    }
    ta.propagator = propagator

    # _log_state writes to disk — mock it out in all tests
    ta._log_state = MagicMock()
    return ta
|
|
|
|
|
|
def test_yields_known_node():
    """A recognised analyst node is emitted as a (node_key, report) tuple."""
    market_update = {"Market Analyst": {"market_report": "bullish outlook"}}
    ta = _make_graph([market_update])

    emitted = list(ta.stream_propagate("NVDA", "2026-03-23"))

    assert emitted == [("market_analyst", "bullish outlook")]
|
|
|
|
|
|
def test_skips_tool_nodes():
    """Tool-node updates are filtered out; only the analyst update is emitted."""
    updates = [
        {"tools_market": {"messages": []}},
        {"Market Analyst": {"market_report": "ok"}},
    ]
    ta = _make_graph(updates)

    emitted = list(ta.stream_propagate("NVDA", "2026-03-23"))

    assert len(emitted) == 1
    first_node = emitted[0][0]
    assert first_node == "market_analyst"
|
|
|
|
|
|
def test_skips_msg_clear_nodes():
    """Message-clearing node updates are filtered out of the emitted stream."""
    updates = [
        {"Msg Clear Market": {}},
        {"News Analyst": {"news_report": "stable"}},
    ]
    ta = _make_graph(updates)

    emitted = list(ta.stream_propagate("NVDA", "2026-03-23"))

    assert len(emitted) == 1
    first_node = emitted[0][0]
    assert first_node == "news_analyst"
|
|
|
|
|
|
def test_skips_unknown_nodes_with_warning(caplog):
    """Unrecognised node names are dropped and reported via a WARNING log record."""
    import logging

    ta = _make_graph([
        {"Unknown Future Node": {"some_field": "value"}},
        {"Trader": {"trader_investment_plan": "buy 100 shares"}},
    ])

    with caplog.at_level(logging.WARNING):
        results = list(ta.stream_propagate("NVDA", "2026-03-23"))

    assert len(results) == 1
    assert results[0][0] == "trader"
    # at_level(WARNING) also lets ERROR/CRITICAL through, so pin the level the
    # test name promises. getMessage() does not depend on a handler having
    # already formatted the record (which is what populates record.message).
    assert any(
        r.levelno == logging.WARNING and "Unknown Future Node" in r.getMessage()
        for r in caplog.records
    )
|
|
|
|
|
|
def test_last_decision_set_after_exhaustion():
    """Exhausting the stream stores the processed final decision on the graph."""
    ta = _make_graph([
        {"Risk Judge": {"risk_debate_state": {"judge_decision": "SELL signal strong"}}},
    ])
    final_text = "strong SELL signal from risk team"

    # graph.get_state() is called post-loop to fetch the full final snapshot
    ta.graph.get_state.return_value = MagicMock(
        next=None,
        values={"final_trade_decision": final_text},
    )
    # Make the processed signal agree with the decision text so the expected
    # value below is not confusingly opposite to the input ("BUY" vs SELL).
    ta.signal_processor.process_signal.return_value = "SELL"

    list(ta.stream_propagate("NVDA", "2026-03-23"))

    assert ta._last_decision == "SELL"
    # Without this, the final_trade_decision set up above is never checked:
    # the test would pass even if the post-loop snapshot were ignored.
    ta.signal_processor.process_signal.assert_called_once()
    args, kwargs = ta.signal_processor.process_signal.call_args
    assert final_text in (*args, *kwargs.values())
|
|
|
|
|
|
def test_bull_researcher_extracts_bull_history():
    """The Bull Researcher node surfaces investment_debate_state['bull_history']."""
    ta = _make_graph([
        {"Bull Researcher": {"investment_debate_state": {
            "bull_history": "bullish case round 1", "bear_history": "",
            "history": "", "current_response": "", "judge_decision": "", "count": 1
        }}},
    ])

    results = list(ta.stream_propagate("NVDA", "2026-03-23"))

    # Compare the whole list (as test_yields_known_node does): an empty stream
    # fails with a readable diff instead of an IndexError, and unexpected extra
    # yields are caught too.
    assert results == [("bull_researcher", "bullish case round 1")]
|
|
|
|
|
|
def test_research_manager_extracts_investment_plan():
    """The Research Manager node surfaces its investment_plan field."""
    ta = _make_graph([
        {"Research Manager": {"investment_plan": "Invest 20% in NVDA"}},
    ])

    results = list(ta.stream_propagate("NVDA", "2026-03-23"))

    # Whole-list comparison, consistent with test_yields_known_node: clear
    # failure on an empty stream and no silently ignored extra yields.
    assert results == [("research_manager", "Invest 20% in NVDA")]
|
|
|
|
|
|
def test_missing_field_yields_empty_string():
    """A known node whose update lacks its report field yields an empty string."""
    ta = _make_graph([
        {"Market Analyst": {}},  # no market_report key
    ])

    results = list(ta.stream_propagate("NVDA", "2026-03-23"))

    # Whole-list comparison, consistent with test_yields_known_node: clear
    # failure on an empty stream and no silently ignored extra yields.
    assert results == [("market_analyst", "")]
|