test: add unit tests for critical bug fixes
Verify the three fixes from 7477240: risk manager reading fundamentals_report (not duplicated news_report), complete state initialization for both debate states, and ChatModel Union type alias including all providers. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d6bb961e23
commit
c52b2ee3c2
|
|
@ -0,0 +1,110 @@
|
|||
"""conftest.py — mock out heavy third-party dependencies AND the internal
|
||||
package __init__ modules that create deep import chains. This lets tests
|
||||
import specific leaf modules (risk_manager.py, propagation.py, setup.py)
|
||||
without triggering the full dependency tree, which requires Python 3.10+
|
||||
and numerous third-party packages.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
_PKG_ROOT = os.path.join(_PROJECT_ROOT, "tradingagents")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# External third-party packages to mock
|
||||
# ---------------------------------------------------------------------------
|
||||
_EXTERNAL_PACKAGES = [
|
||||
"langchain_core",
|
||||
"langchain_core.messages",
|
||||
"langchain_core.tools",
|
||||
"langchain_core.utils",
|
||||
"langchain_core.utils.function_calling",
|
||||
"langchain_openai",
|
||||
"langchain_anthropic",
|
||||
"langchain_google_genai",
|
||||
"langchain_experimental",
|
||||
"langchain_experimental.utilities",
|
||||
"langgraph",
|
||||
"langgraph.graph",
|
||||
"langgraph.prebuilt",
|
||||
"yfinance",
|
||||
"pandas",
|
||||
"backtrader",
|
||||
"stockstats",
|
||||
"rank_bm25",
|
||||
"requests",
|
||||
"parsel",
|
||||
"redis",
|
||||
"chainlit",
|
||||
"questionary",
|
||||
"typer",
|
||||
"rich",
|
||||
"rich.console",
|
||||
"rich.panel",
|
||||
"rich.table",
|
||||
"rich.progress",
|
||||
"tqdm",
|
||||
"pytz",
|
||||
"setuptools",
|
||||
]
|
||||
|
||||
|
||||
def _make_mock(name, path=None):
|
||||
"""Create a MagicMock that behaves like a module/package."""
|
||||
mock_mod = MagicMock()
|
||||
mock_mod.__name__ = name
|
||||
mock_mod.__file__ = f"<mocked {name}>"
|
||||
mock_mod.__path__ = [path] if path else []
|
||||
mock_mod.__all__ = []
|
||||
mock_mod.__spec__ = None
|
||||
mock_mod.__package__ = name
|
||||
return mock_mod
|
||||
|
||||
|
||||
# Install external mocks
|
||||
for _pkg in _EXTERNAL_PACKAGES:
|
||||
if _pkg not in sys.modules:
|
||||
sys.modules[_pkg] = _make_mock(_pkg)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal packages: replace __init__.py-level imports with stubs that have
|
||||
# real __path__ entries so importlib can still find submodule .py files.
|
||||
# ---------------------------------------------------------------------------
|
||||
_INTERNAL_PKG_DIRS = {
|
||||
"tradingagents.graph": os.path.join(_PKG_ROOT, "graph"),
|
||||
"tradingagents.graph.conditional_logic": None, # leaf, no subdir
|
||||
"tradingagents.agents": os.path.join(_PKG_ROOT, "agents"),
|
||||
"tradingagents.agents.utils": os.path.join(_PKG_ROOT, "agents", "utils"),
|
||||
"tradingagents.agents.utils.agent_utils": None,
|
||||
"tradingagents.agents.utils.memory": None,
|
||||
"tradingagents.agents.managers": os.path.join(_PKG_ROOT, "agents", "managers"),
|
||||
"tradingagents.dataflows": os.path.join(_PKG_ROOT, "dataflows"),
|
||||
"tradingagents.dataflows.interface": None,
|
||||
}
|
||||
|
||||
for _pkg, _dir in _INTERNAL_PKG_DIRS.items():
|
||||
if _pkg not in sys.modules:
|
||||
sys.modules[_pkg] = _make_mock(_pkg, _dir)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Now import the real leaf modules we want to test.
|
||||
# ---------------------------------------------------------------------------
|
||||
import importlib
|
||||
|
||||
# agent_states — pure TypedDict definitions, needs typing_extensions + mocked langgraph
|
||||
_agent_states = importlib.import_module("tradingagents.agents.utils.agent_states")
|
||||
sys.modules["tradingagents.agents.utils.agent_states"] = _agent_states
|
||||
|
||||
# propagation.py — imports agent_states (now real)
|
||||
_propagation = importlib.import_module("tradingagents.graph.propagation")
|
||||
sys.modules["tradingagents.graph.propagation"] = _propagation
|
||||
|
||||
# risk_manager.py — standalone, only uses llm.invoke and memory
|
||||
_risk_manager = importlib.import_module("tradingagents.agents.managers.risk_manager")
|
||||
sys.modules["tradingagents.agents.managers.risk_manager"] = _risk_manager
|
||||
|
||||
# setup.py — defines ChatModel type alias; imports langchain_* (mocked) and .conditional_logic (mocked)
|
||||
_setup = importlib.import_module("tradingagents.graph.setup")
|
||||
sys.modules["tradingagents.graph.setup"] = _setup
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
"""Tests for state initialization (tradingagents/graph/propagation.py).
|
||||
|
||||
Verifies that create_initial_state produces complete InvestDebateState and
|
||||
RiskDebateState dicts with all required fields (the incomplete-state bug fix).
|
||||
"""
|
||||
|
||||
from tradingagents.graph.propagation import Propagator
|
||||
|
||||
INVEST_DEBATE_FIELDS = [
|
||||
"bull_history",
|
||||
"bear_history",
|
||||
"history",
|
||||
"current_response",
|
||||
"judge_decision",
|
||||
"count",
|
||||
]
|
||||
|
||||
RISK_DEBATE_FIELDS = [
|
||||
"aggressive_history",
|
||||
"conservative_history",
|
||||
"neutral_history",
|
||||
"history",
|
||||
"latest_speaker",
|
||||
"current_aggressive_response",
|
||||
"current_conservative_response",
|
||||
"current_neutral_response",
|
||||
"judge_decision",
|
||||
"count",
|
||||
]
|
||||
|
||||
|
||||
def _initial_state():
|
||||
return Propagator().create_initial_state("AAPL", "2025-01-01")
|
||||
|
||||
|
||||
def test_initial_invest_debate_state_has_all_fields():
|
||||
state = _initial_state()
|
||||
invest = state["investment_debate_state"]
|
||||
for field in INVEST_DEBATE_FIELDS:
|
||||
assert field in invest, f"InvestDebateState missing field: {field}"
|
||||
|
||||
|
||||
def test_initial_risk_debate_state_has_all_fields():
|
||||
state = _initial_state()
|
||||
risk = state["risk_debate_state"]
|
||||
for field in RISK_DEBATE_FIELDS:
|
||||
assert field in risk, f"RiskDebateState missing field: {field}"
|
||||
|
||||
|
||||
def test_initial_state_fields_are_empty_defaults():
|
||||
state = _initial_state()
|
||||
|
||||
invest = state["investment_debate_state"]
|
||||
for field in INVEST_DEBATE_FIELDS:
|
||||
if field == "count":
|
||||
assert invest[field] == 0, f"InvestDebateState.{field} should be 0"
|
||||
else:
|
||||
assert invest[field] == "", f"InvestDebateState.{field} should be empty string"
|
||||
|
||||
risk = state["risk_debate_state"]
|
||||
for field in RISK_DEBATE_FIELDS:
|
||||
if field == "count":
|
||||
assert risk[field] == 0, f"RiskDebateState.{field} should be 0"
|
||||
else:
|
||||
assert risk[field] == "", f"RiskDebateState.{field} should be empty string"
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
"""Tests for the risk manager node (tradingagents/agents/managers/risk_manager.py).
|
||||
|
||||
Verifies the copy-paste bug fix: the risk manager must use fundamentals_report
|
||||
(not a duplicate of news_report) when building its situation string.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tradingagents.agents.managers.risk_manager import create_risk_manager
|
||||
|
||||
|
||||
def _make_state(news="news-text", fundamentals="fundamentals-text"):
|
||||
"""Return a minimal state dict suitable for risk_manager_node."""
|
||||
return {
|
||||
"company_of_interest": "AAPL",
|
||||
"risk_debate_state": {
|
||||
"history": "debate history",
|
||||
"aggressive_history": "",
|
||||
"conservative_history": "",
|
||||
"neutral_history": "",
|
||||
"latest_speaker": "",
|
||||
"current_aggressive_response": "",
|
||||
"current_conservative_response": "",
|
||||
"current_neutral_response": "",
|
||||
"judge_decision": "",
|
||||
"count": 0,
|
||||
},
|
||||
"market_report": "market-text",
|
||||
"news_report": news,
|
||||
"fundamentals_report": fundamentals,
|
||||
"sentiment_report": "sentiment-text",
|
||||
"investment_plan": "plan-text",
|
||||
}
|
||||
|
||||
|
||||
def test_risk_manager_reads_fundamentals_report_not_news():
|
||||
"""The curr_situation string must contain the fundamentals_report value,
|
||||
not a second copy of news_report (the bug that was fixed at line 14)."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(content="BUY")
|
||||
|
||||
memory = MagicMock()
|
||||
memory.get_memories.return_value = []
|
||||
|
||||
node = create_risk_manager(llm, memory)
|
||||
state = _make_state(news="NEWS_UNIQUE", fundamentals="FUNDAMENTALS_UNIQUE")
|
||||
node(state)
|
||||
|
||||
# The LLM should have been called once; grab the prompt
|
||||
llm.invoke.assert_called_once()
|
||||
prompt = llm.invoke.call_args[0][0]
|
||||
|
||||
# curr_situation is passed to memory.get_memories, not directly to LLM,
|
||||
# but the fundamentals text appears in the prompt via the debate history context.
|
||||
# More directly: memory.get_memories receives curr_situation as its first arg.
|
||||
memory.get_memories.assert_called_once()
|
||||
situation_arg = memory.get_memories.call_args[0][0]
|
||||
|
||||
assert "FUNDAMENTALS_UNIQUE" in situation_arg, (
|
||||
"fundamentals_report should appear in the situation string"
|
||||
)
|
||||
# Also verify news is present (it should be there once, not duplicated for fundamentals)
|
||||
assert "NEWS_UNIQUE" in situation_arg, (
|
||||
"news_report should appear in the situation string"
|
||||
)
|
||||
|
||||
|
||||
def test_risk_manager_returns_expected_state_keys():
|
||||
"""The node must return a dict with 'risk_debate_state' and 'final_trade_decision'."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(content="HOLD")
|
||||
|
||||
memory = MagicMock()
|
||||
memory.get_memories.return_value = []
|
||||
|
||||
node = create_risk_manager(llm, memory)
|
||||
result = node(_make_state())
|
||||
|
||||
assert "risk_debate_state" in result
|
||||
assert "final_trade_decision" in result
|
||||
assert result["final_trade_decision"] == "HOLD"
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""Tests for the ChatModel type alias (tradingagents/graph/setup.py).
|
||||
|
||||
Verifies that ChatModel is a Union containing all three supported providers.
|
||||
"""
|
||||
|
||||
import typing
|
||||
|
||||
from tradingagents.graph.setup import ChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
|
||||
def test_chat_model_union_includes_all_providers():
|
||||
"""ChatModel should be a Union of ChatOpenAI, ChatAnthropic, and ChatGoogleGenerativeAI."""
|
||||
args = typing.get_args(ChatModel)
|
||||
assert ChatOpenAI in args, "ChatModel should include ChatOpenAI"
|
||||
assert ChatAnthropic in args, "ChatModel should include ChatAnthropic"
|
||||
assert ChatGoogleGenerativeAI in args, "ChatModel should include ChatGoogleGenerativeAI"
|
||||
Loading…
Reference in New Issue