"""Tests for Issue #17: Analyst integration into graph/setup.py workflow.

This module tests the integration of new analysts (momentum, macro,
correlation) into the TradingAgents graph workflow.
"""

import pytest
from unittest.mock import MagicMock, patch
from typing import get_type_hints
import sys

# Check if langchain dependencies are available
try:
    import langchain_core  # noqa: F401  (imported only to probe availability)

    LANGCHAIN_AVAILABLE = True
except ImportError:
    LANGCHAIN_AVAILABLE = False

# Skip all tests if langchain not available
pytestmark = pytest.mark.skipif(
    not LANGCHAIN_AVAILABLE, reason="langchain_core not installed"
)


class TestAgentStateReports:
    """Test that AgentState has the new report fields."""

    @staticmethod
    def _agent_state_hints():
        """Return the resolved type hints of ``AgentState``."""
        from tradingagents.agents.utils.agent_states import AgentState

        return get_type_hints(AgentState)

    def test_agent_state_has_momentum_report(self):
        """AgentState should have momentum_report field."""
        assert "momentum_report" in self._agent_state_hints()

    def test_agent_state_has_macro_report(self):
        """AgentState should have macro_report field."""
        assert "macro_report" in self._agent_state_hints()

    def test_agent_state_has_correlation_report(self):
        """AgentState should have correlation_report field."""
        assert "correlation_report" in self._agent_state_hints()


class TestConditionalLogicMethods:
    """Test that ConditionalLogic has methods for new analysts."""

    @staticmethod
    def _logic():
        """Return a fresh ``ConditionalLogic`` instance."""
        from tradingagents.graph.conditional_logic import ConditionalLogic

        return ConditionalLogic()

    @staticmethod
    def _state_with_tool_calls(tool_calls):
        """Build a minimal state whose only message carries ``tool_calls``."""
        message = MagicMock()
        message.tool_calls = tool_calls
        return {"messages": [message]}

    def test_should_continue_momentum_exists(self):
        """ConditionalLogic should have should_continue_momentum method."""
        cl = self._logic()
        assert hasattr(cl, "should_continue_momentum")
        assert callable(cl.should_continue_momentum)

    def test_should_continue_macro_exists(self):
        """ConditionalLogic should have should_continue_macro method."""
        cl = self._logic()
        assert hasattr(cl, "should_continue_macro")
        assert callable(cl.should_continue_macro)

    def test_should_continue_correlation_exists(self):
        """ConditionalLogic should have should_continue_correlation method."""
        cl = self._logic()
        assert hasattr(cl, "should_continue_correlation")
        assert callable(cl.should_continue_correlation)

    def test_momentum_conditional_returns_tools(self):
        """should_continue_momentum should return 'tools_momentum' when tool_calls exist."""
        state = self._state_with_tool_calls([{"name": "test"}])
        assert self._logic().should_continue_momentum(state) == "tools_momentum"

    def test_momentum_conditional_returns_clear(self):
        """should_continue_momentum should return 'Msg Clear Momentum' when no tool_calls."""
        state = self._state_with_tool_calls([])
        assert self._logic().should_continue_momentum(state) == "Msg Clear Momentum"

    def test_macro_conditional_returns_tools(self):
        """should_continue_macro should return 'tools_macro' when tool_calls exist."""
        state = self._state_with_tool_calls([{"name": "test"}])
        assert self._logic().should_continue_macro(state) == "tools_macro"

    def test_macro_conditional_returns_clear(self):
        """should_continue_macro should return 'Msg Clear Macro' when no tool_calls."""
        state = self._state_with_tool_calls([])
        assert self._logic().should_continue_macro(state) == "Msg Clear Macro"

    def test_correlation_conditional_returns_tools(self):
        """should_continue_correlation should return 'tools_correlation' when tool_calls exist."""
        state = self._state_with_tool_calls([{"name": "test"}])
        assert self._logic().should_continue_correlation(state) == "tools_correlation"

    def test_correlation_conditional_returns_clear(self):
        """should_continue_correlation should return 'Msg Clear Correlation' when no tool_calls."""
        state = self._state_with_tool_calls([])
        assert (
            self._logic().should_continue_correlation(state)
            == "Msg Clear Correlation"
        )


class TestAgentImports:
    """Test that new analysts are properly exported from agents module."""

    def test_create_momentum_analyst_import(self):
        """create_momentum_analyst should be importable from agents."""
        from tradingagents.agents import create_momentum_analyst

        assert callable(create_momentum_analyst)

    def test_create_macro_analyst_import(self):
        """create_macro_analyst should be importable from agents."""
        from tradingagents.agents import create_macro_analyst

        assert callable(create_macro_analyst)

    def test_create_correlation_analyst_import(self):
        """create_correlation_analyst should be importable from agents."""
        from tradingagents.agents import create_correlation_analyst

        assert callable(create_correlation_analyst)

    def test_create_position_sizing_manager_import(self):
        """create_position_sizing_manager should be importable from agents."""
        from tradingagents.agents import create_position_sizing_manager

        assert callable(create_position_sizing_manager)


class TestAnalystsModuleExports:
    """Test analysts submodule exports."""

    def test_analysts_module_exports_momentum(self):
        """analysts module should export create_momentum_analyst."""
        from tradingagents.agents.analysts import create_momentum_analyst

        assert callable(create_momentum_analyst)

    def test_analysts_module_exports_macro(self):
        """analysts module should export create_macro_analyst."""
        from tradingagents.agents.analysts import create_macro_analyst

        assert callable(create_macro_analyst)

    def test_analysts_module_exports_correlation(self):
        """analysts module should export create_correlation_analyst."""
        from tradingagents.agents.analysts import create_correlation_analyst

        assert callable(create_correlation_analyst)

    def test_analysts_module_all_exports(self):
        """analysts __all__ should include all analyst creators."""
        from tradingagents.agents import analysts

        assert "create_momentum_analyst" in analysts.__all__
        assert "create_macro_analyst" in analysts.__all__
        assert "create_correlation_analyst" in analysts.__all__


class TestManagersModuleExports:
    """Test managers submodule exports."""

    def test_managers_module_exports_position_sizing(self):
        """managers module should export create_position_sizing_manager."""
        from tradingagents.agents.managers import create_position_sizing_manager

        assert callable(create_position_sizing_manager)

    def test_managers_module_all_exports(self):
        """managers __all__ should include position_sizing_manager."""
        from tradingagents.agents import managers

        assert "create_position_sizing_manager" in managers.__all__


class TestToolImports:
    """Test that tool functions are importable from analyst modules."""

    def test_momentum_tools_importable(self):
        """Momentum analyst tools should be importable."""
        from tradingagents.agents.analysts.momentum_analyst import (
            get_multi_timeframe_momentum,
            get_adx_analysis,
            get_momentum_divergence,
        )

        assert callable(get_multi_timeframe_momentum)
        assert callable(get_adx_analysis)
        assert callable(get_momentum_divergence)

    def test_macro_tools_importable(self):
        """Macro analyst tools should be importable."""
        from tradingagents.agents.analysts.macro_analyst import (
            get_economic_regime_analysis,
            get_yield_curve_analysis,
            get_monetary_policy_analysis,
            get_inflation_regime_analysis,
        )

        assert callable(get_economic_regime_analysis)
        assert callable(get_yield_curve_analysis)
        assert callable(get_monetary_policy_analysis)
        assert callable(get_inflation_regime_analysis)

    def test_correlation_tools_importable(self):
        """Correlation analyst tools should be importable."""
        from tradingagents.agents.analysts.correlation_analyst import (
            get_cross_asset_correlation_analysis,
            get_sector_rotation_analysis,
            get_correlation_matrix,
            get_rolling_correlation_trend,
        )

        assert callable(get_cross_asset_correlation_analysis)
        assert callable(get_sector_rotation_analysis)
        assert callable(get_correlation_matrix)
        assert callable(get_rolling_correlation_trend)


class TestTradingGraphToolNodes:
    """Test that trading_graph.py has tool nodes for new analysts."""

    @staticmethod
    def _tool_nodes():
        """Return ``_create_tool_nodes()`` of an uninitialised ``TradingAgentsGraph``.

        ``object.__new__`` never runs ``__init__``, so no constructor work
        happens here; patching ``__init__`` as well would have no effect.
        """
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
        return graph._create_tool_nodes()

    # The decorator patches inject mocks positionally (bottom-most first); the
    # arguments are unused but must stay in the signature.
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.set_config")
    @patch("os.makedirs")
    def test_tool_nodes_include_momentum(self, mock_makedirs, mock_set_config, mock_llm):
        """_create_tool_nodes should include momentum tools."""
        assert "momentum" in self._tool_nodes()

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.set_config")
    @patch("os.makedirs")
    def test_tool_nodes_include_macro(self, mock_makedirs, mock_set_config, mock_llm):
        """_create_tool_nodes should include macro tools."""
        assert "macro" in self._tool_nodes()

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.set_config")
    @patch("os.makedirs")
    def test_tool_nodes_include_correlation(self, mock_makedirs, mock_set_config, mock_llm):
        """_create_tool_nodes should include correlation tools."""
        assert "correlation" in self._tool_nodes()


class TestGraphSetupDocstring:
    """Test that setup_graph docstring documents new analysts."""

    @staticmethod
    def _setup_graph_docstring():
        """Return the lower-cased ``GraphSetup.setup_graph`` docstring.

        Fails with a clear message (instead of ``AttributeError`` on ``None``)
        when the method has no docstring at all.
        """
        from tradingagents.graph.setup import GraphSetup

        docstring = GraphSetup.setup_graph.__doc__
        assert docstring, "GraphSetup.setup_graph has no docstring"
        return docstring.lower()

    def test_docstring_mentions_momentum(self):
        """setup_graph docstring should mention momentum analyst."""
        assert "momentum" in self._setup_graph_docstring()

    def test_docstring_mentions_macro(self):
        """setup_graph docstring should mention macro analyst."""
        assert "macro" in self._setup_graph_docstring()

    def test_docstring_mentions_correlation(self):
        """setup_graph docstring should mention correlation analyst."""
        assert "correlation" in self._setup_graph_docstring()


class TestLogStateNewReports:
    """Test that _log_state includes new report fields."""

    # Report fields added by Issue #17; states from older runs lack them.
    NEW_REPORT_FIELDS = ("momentum_report", "macro_report", "correlation_report")

    @staticmethod
    def _make_final_state(**overrides):
        """Return a complete final state with every report/text field empty.

        Keyword arguments replace the matching top-level entries, e.g.
        ``_make_final_state(macro_report="...")``. A brand-new dict (nested
        debate states included) is built on every call, so tests cannot
        leak mutations into each other.
        """
        state = {
            "company_of_interest": "TEST",
            "trade_date": "2024-01-01",
            "market_report": "",
            "sentiment_report": "",
            "news_report": "",
            "fundamentals_report": "",
            "momentum_report": "",
            "macro_report": "",
            "correlation_report": "",
            "investment_debate_state": {
                "bull_history": "",
                "bear_history": "",
                "history": "",
                "current_response": "",
                "judge_decision": "",
            },
            "trader_investment_plan": "",
            "risk_debate_state": {
                "risky_history": "",
                "safe_history": "",
                "neutral_history": "",
                "history": "",
                "judge_decision": "",
            },
            "investment_plan": "",
            "final_trade_decision": "",
        }
        state.update(overrides)
        return state

    @staticmethod
    def _run_log_state(final_state):
        """Run ``_log_state`` with file I/O patched out; return the logged entry."""
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # __new__ bypasses __init__, so set by hand the attributes the
        # original tests supplied for _log_state.
        graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
        graph.log_states_dict = {}
        graph.ticker = "TEST"

        open_patch = patch("builtins.open", MagicMock())
        with open_patch, patch("json.dump"), patch("pathlib.Path.mkdir"):
            graph._log_state("2024-01-01", final_state)
        return graph.log_states_dict["2024-01-01"]

    def test_log_state_includes_momentum_report(self):
        """_log_state should log momentum_report if present."""
        logged = self._run_log_state(
            self._make_final_state(momentum_report="Momentum analysis result")
        )
        assert "momentum_report" in logged
        assert logged["momentum_report"] == "Momentum analysis result"

    def test_log_state_includes_macro_report(self):
        """_log_state should log macro_report if present."""
        logged = self._run_log_state(
            self._make_final_state(macro_report="Macro analysis result")
        )
        assert "macro_report" in logged
        assert logged["macro_report"] == "Macro analysis result"

    def test_log_state_includes_correlation_report(self):
        """_log_state should log correlation_report if present."""
        logged = self._run_log_state(
            self._make_final_state(correlation_report="Correlation analysis result")
        )
        assert "correlation_report" in logged
        assert logged["correlation_report"] == "Correlation analysis result"

    def test_log_state_handles_missing_new_reports(self):
        """_log_state should handle missing new report fields gracefully."""
        # Backward compatibility: a state from before Issue #17 still carries
        # every legacy field (market/sentiment/news/fundamentals reports);
        # only the new report fields are absent.
        final_state = self._make_final_state()
        for field in self.NEW_REPORT_FIELDS:
            del final_state[field]

        logged = self._run_log_state(final_state)

        # Should default to empty string
        for field in self.NEW_REPORT_FIELDS:
            assert logged[field] == "", field