diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..7a6a53e2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,110 @@ +"""conftest.py — mock out heavy third-party dependencies AND the internal +package __init__ modules that create deep import chains. This lets tests +import specific leaf modules (risk_manager.py, propagation.py, setup.py) +without triggering the full dependency tree, which requires Python 3.10+ +and numerous third-party packages. +""" + +import os +import sys +from unittest.mock import MagicMock + +_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +_PKG_ROOT = os.path.join(_PROJECT_ROOT, "tradingagents") + +# --------------------------------------------------------------------------- +# External third-party packages to mock +# --------------------------------------------------------------------------- +_EXTERNAL_PACKAGES = [ + "langchain_core", + "langchain_core.messages", + "langchain_core.tools", + "langchain_core.utils", + "langchain_core.utils.function_calling", + "langchain_openai", + "langchain_anthropic", + "langchain_google_genai", + "langchain_experimental", + "langchain_experimental.utilities", + "langgraph", + "langgraph.graph", + "langgraph.prebuilt", + "yfinance", + "pandas", + "backtrader", + "stockstats", + "rank_bm25", + "requests", + "parsel", + "redis", + "chainlit", + "questionary", + "typer", + "rich", + "rich.console", + "rich.panel", + "rich.table", + "rich.progress", + "tqdm", + "pytz", + "setuptools", +] + + +def _make_mock(name, path=None): + """Create a MagicMock that behaves like a module/package.""" + mock_mod = MagicMock() + mock_mod.__name__ = name + mock_mod.__file__ = f"" + mock_mod.__path__ = [path] if path else [] + mock_mod.__all__ = [] + mock_mod.__spec__ = None + mock_mod.__package__ = name + return mock_mod + + +# Install external mocks +for _pkg in _EXTERNAL_PACKAGES: + if _pkg not in sys.modules: + sys.modules[_pkg] = _make_mock(_pkg) + +# --------------------------------------------------------------------------- +# Internal packages: replace __init__.py-level imports with stubs that have +# real __path__ entries so importlib can still find submodule .py files. +# --------------------------------------------------------------------------- +_INTERNAL_PKG_DIRS = { + "tradingagents.graph": os.path.join(_PKG_ROOT, "graph"), + "tradingagents.graph.conditional_logic": None, # leaf, no subdir + "tradingagents.agents": os.path.join(_PKG_ROOT, "agents"), + "tradingagents.agents.utils": os.path.join(_PKG_ROOT, "agents", "utils"), + "tradingagents.agents.utils.agent_utils": None, + "tradingagents.agents.utils.memory": None, + "tradingagents.agents.managers": os.path.join(_PKG_ROOT, "agents", "managers"), + "tradingagents.dataflows": os.path.join(_PKG_ROOT, "dataflows"), + "tradingagents.dataflows.interface": None, +} + +for _pkg, _dir in _INTERNAL_PKG_DIRS.items(): + if _pkg not in sys.modules: + sys.modules[_pkg] = _make_mock(_pkg, _dir) + +# --------------------------------------------------------------------------- +# Now import the real leaf modules we want to test. +# --------------------------------------------------------------------------- +import importlib + +# agent_states — pure TypedDict definitions, needs typing_extensions + mocked langgraph +_agent_states = importlib.import_module("tradingagents.agents.utils.agent_states") +sys.modules["tradingagents.agents.utils.agent_states"] = _agent_states + +# propagation.py — imports agent_states (now real) +_propagation = importlib.import_module("tradingagents.graph.propagation") +sys.modules["tradingagents.graph.propagation"] = _propagation + +# risk_manager.py — standalone, only uses llm.invoke and memory +_risk_manager = importlib.import_module("tradingagents.agents.managers.risk_manager") +sys.modules["tradingagents.agents.managers.risk_manager"] = _risk_manager + +# setup.py — defines ChatModel type alias; imports langchain_* (mocked) and .conditional_logic (mocked) +_setup = importlib.import_module("tradingagents.graph.setup") +sys.modules["tradingagents.graph.setup"] = _setup diff --git a/tests/test_propagation.py b/tests/test_propagation.py new file mode 100644 index 00000000..66e2a74e --- /dev/null +++ b/tests/test_propagation.py @@ -0,0 +1,65 @@ +"""Tests for state initialization (tradingagents/graph/propagation.py). + +Verifies that create_initial_state produces complete InvestDebateState and +RiskDebateState dicts with all required fields (the incomplete-state bug fix). +""" + +from tradingagents.graph.propagation import Propagator + +INVEST_DEBATE_FIELDS = [ + "bull_history", + "bear_history", + "history", + "current_response", + "judge_decision", + "count", +] + +RISK_DEBATE_FIELDS = [ + "aggressive_history", + "conservative_history", + "neutral_history", + "history", + "latest_speaker", + "current_aggressive_response", + "current_conservative_response", + "current_neutral_response", + "judge_decision", + "count", +] + + +def _initial_state(): + return Propagator().create_initial_state("AAPL", "2025-01-01") + + +def test_initial_invest_debate_state_has_all_fields(): + state = _initial_state() + invest = state["investment_debate_state"] + for field in INVEST_DEBATE_FIELDS: + assert field in invest, f"InvestDebateState missing field: {field}" + + +def test_initial_risk_debate_state_has_all_fields(): + state = _initial_state() + risk = state["risk_debate_state"] + for field in RISK_DEBATE_FIELDS: + assert field in risk, f"RiskDebateState missing field: {field}" + + +def test_initial_state_fields_are_empty_defaults(): + state = _initial_state() + + invest = state["investment_debate_state"] + for field in INVEST_DEBATE_FIELDS: + if field == "count": + assert invest[field] == 0, f"InvestDebateState.{field} should be 0" + else: + assert invest[field] == "", f"InvestDebateState.{field} should be empty string" + + risk = state["risk_debate_state"] + for field in RISK_DEBATE_FIELDS: + if field == "count": + assert risk[field] == 0, f"RiskDebateState.{field} should be 0" + else: + assert risk[field] == "", f"RiskDebateState.{field} should be empty string" diff --git a/tests/test_risk_manager.py b/tests/test_risk_manager.py new file mode 100644 index 00000000..6b16244e --- /dev/null +++ b/tests/test_risk_manager.py @@ -0,0 +1,81 @@ +"""Tests for the risk manager node (tradingagents/agents/managers/risk_manager.py). + +Verifies the copy-paste bug fix: the risk manager must use fundamentals_report +(not a duplicate of news_report) when building its situation string. +""" + +from unittest.mock import MagicMock + +from tradingagents.agents.managers.risk_manager import create_risk_manager + + +def _make_state(news="news-text", fundamentals="fundamentals-text"): + """Return a minimal state dict suitable for risk_manager_node.""" + return { + "company_of_interest": "AAPL", + "risk_debate_state": { + "history": "debate history", + "aggressive_history": "", + "conservative_history": "", + "neutral_history": "", + "latest_speaker": "", + "current_aggressive_response": "", + "current_conservative_response": "", + "current_neutral_response": "", + "judge_decision": "", + "count": 0, + }, + "market_report": "market-text", + "news_report": news, + "fundamentals_report": fundamentals, + "sentiment_report": "sentiment-text", + "investment_plan": "plan-text", + } + + +def test_risk_manager_reads_fundamentals_report_not_news(): + """The curr_situation string must contain the fundamentals_report value, + not a second copy of news_report (the bug that was fixed at line 14).""" + llm = MagicMock() + llm.invoke.return_value = MagicMock(content="BUY") + + memory = MagicMock() + memory.get_memories.return_value = [] + + node = create_risk_manager(llm, memory) + state = _make_state(news="NEWS_UNIQUE", fundamentals="FUNDAMENTALS_UNIQUE") + node(state) + + # The LLM should have been called once; grab the prompt + llm.invoke.assert_called_once() + prompt = llm.invoke.call_args[0][0] + + # curr_situation is passed to memory.get_memories, not directly to LLM, + # but the fundamentals text appears in the prompt via the debate history context. + # More directly: memory.get_memories receives curr_situation as its first arg. + memory.get_memories.assert_called_once() + situation_arg = memory.get_memories.call_args[0][0] + + assert "FUNDAMENTALS_UNIQUE" in situation_arg, ( + "fundamentals_report should appear in the situation string" + ) + # Also verify news is present (it should be there once, not duplicated for fundamentals) + assert "NEWS_UNIQUE" in situation_arg, ( + "news_report should appear in the situation string" + ) + + +def test_risk_manager_returns_expected_state_keys(): + """The node must return a dict with 'risk_debate_state' and 'final_trade_decision'.""" + llm = MagicMock() + llm.invoke.return_value = MagicMock(content="HOLD") + + memory = MagicMock() + memory.get_memories.return_value = [] + + node = create_risk_manager(llm, memory) + result = node(_make_state()) + + assert "risk_debate_state" in result + assert "final_trade_decision" in result + assert result["final_trade_decision"] == "HOLD" diff --git a/tests/test_setup_types.py b/tests/test_setup_types.py new file mode 100644 index 00000000..0995c94d --- /dev/null +++ b/tests/test_setup_types.py @@ -0,0 +1,19 @@ +"""Tests for the ChatModel type alias (tradingagents/graph/setup.py). + +Verifies that ChatModel is a Union containing all three supported providers. +""" + +import typing + +from tradingagents.graph.setup import ChatModel +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain_google_genai import ChatGoogleGenerativeAI + + +def test_chat_model_union_includes_all_providers(): + """ChatModel should be a Union of ChatOpenAI, ChatAnthropic, and ChatGoogleGenerativeAI.""" + args = typing.get_args(ChatModel) + assert ChatOpenAI in args, "ChatModel should include ChatOpenAI" + assert ChatAnthropic in args, "ChatModel should include ChatAnthropic" + assert ChatGoogleGenerativeAI in args, "ChatModel should include ChatGoogleGenerativeAI"