507 lines
20 KiB
Python
507 lines
20 KiB
Python
"""Tests for Issue #17: analyst integration into the graph/setup.py workflow.

This module tests the integration of the new analysts (momentum, macro,
correlation) and the position sizing manager into the TradingAgents graph
workflow.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from typing import get_type_hints
|
|
import sys
|
|
|
|
# Probe for langchain_core. The tests below import tradingagents modules that
# presumably depend on it, so the whole module is skipped when it is missing.
try:
    import langchain_core  # noqa: F401  (imported only to probe availability)
except ImportError:
    LANGCHAIN_AVAILABLE = False
else:
    LANGCHAIN_AVAILABLE = True

# Module-level marker: pytest applies it to every test collected from this file.
pytestmark = pytest.mark.skipif(
    not LANGCHAIN_AVAILABLE,
    reason="langchain_core not installed",
)
|
|
|
|
|
|
class TestAgentStateReports:
    """Verify AgentState declares a field for each new analyst report."""

    @staticmethod
    def _state_fields():
        """Return AgentState's resolved type hints, keyed by field name."""
        from tradingagents.agents.utils.agent_states import AgentState

        return get_type_hints(AgentState)

    def test_agent_state_has_momentum_report(self):
        """AgentState declares a momentum_report field."""
        assert "momentum_report" in self._state_fields()

    def test_agent_state_has_macro_report(self):
        """AgentState declares a macro_report field."""
        assert "macro_report" in self._state_fields()

    def test_agent_state_has_correlation_report(self):
        """AgentState declares a correlation_report field."""
        assert "correlation_report" in self._state_fields()
|
|
|
|
|
|
class TestConditionalLogicMethods:
    """Check the ConditionalLogic routing hooks added for the new analysts."""

    @staticmethod
    def _new_logic():
        """Import ConditionalLogic lazily and return a default-constructed instance."""
        from tradingagents.graph.conditional_logic import ConditionalLogic

        return ConditionalLogic()

    @staticmethod
    def _state_with(tool_calls):
        """Return a graph state whose single message carries ``tool_calls``."""
        last_message = MagicMock()
        last_message.tool_calls = tool_calls
        return {"messages": [last_message]}

    def _assert_hook_exists(self, hook_name):
        """Fail unless a ConditionalLogic instance exposes ``hook_name`` as a callable."""
        logic = self._new_logic()
        assert hasattr(logic, hook_name)
        assert callable(getattr(logic, hook_name))

    def test_should_continue_momentum_exists(self):
        """ConditionalLogic exposes a callable should_continue_momentum."""
        self._assert_hook_exists("should_continue_momentum")

    def test_should_continue_macro_exists(self):
        """ConditionalLogic exposes a callable should_continue_macro."""
        self._assert_hook_exists("should_continue_macro")

    def test_should_continue_correlation_exists(self):
        """ConditionalLogic exposes a callable should_continue_correlation."""
        self._assert_hook_exists("should_continue_correlation")

    def test_momentum_conditional_returns_tools(self):
        """With pending tool calls, momentum routing targets 'tools_momentum'."""
        route = self._new_logic().should_continue_momentum(
            self._state_with([{"name": "test"}])
        )
        assert route == "tools_momentum"

    def test_momentum_conditional_returns_clear(self):
        """Without tool calls, momentum routing targets 'Msg Clear Momentum'."""
        route = self._new_logic().should_continue_momentum(self._state_with([]))
        assert route == "Msg Clear Momentum"

    def test_macro_conditional_returns_tools(self):
        """With pending tool calls, macro routing targets 'tools_macro'."""
        route = self._new_logic().should_continue_macro(
            self._state_with([{"name": "test"}])
        )
        assert route == "tools_macro"

    def test_macro_conditional_returns_clear(self):
        """Without tool calls, macro routing targets 'Msg Clear Macro'."""
        route = self._new_logic().should_continue_macro(self._state_with([]))
        assert route == "Msg Clear Macro"

    def test_correlation_conditional_returns_tools(self):
        """With pending tool calls, correlation routing targets 'tools_correlation'."""
        route = self._new_logic().should_continue_correlation(
            self._state_with([{"name": "test"}])
        )
        assert route == "tools_correlation"

    def test_correlation_conditional_returns_clear(self):
        """Without tool calls, correlation routing targets 'Msg Clear Correlation'."""
        route = self._new_logic().should_continue_correlation(self._state_with([]))
        assert route == "Msg Clear Correlation"
|
|
|
|
|
|
class TestAgentImports:
    """Verify the top-level agents package re-exports the new factories."""

    def test_create_momentum_analyst_import(self):
        """tradingagents.agents re-exports a callable create_momentum_analyst."""
        from tradingagents.agents import create_momentum_analyst as factory

        assert callable(factory)

    def test_create_macro_analyst_import(self):
        """tradingagents.agents re-exports a callable create_macro_analyst."""
        from tradingagents.agents import create_macro_analyst as factory

        assert callable(factory)

    def test_create_correlation_analyst_import(self):
        """tradingagents.agents re-exports a callable create_correlation_analyst."""
        from tradingagents.agents import create_correlation_analyst as factory

        assert callable(factory)

    def test_create_position_sizing_manager_import(self):
        """tradingagents.agents re-exports a callable create_position_sizing_manager."""
        from tradingagents.agents import create_position_sizing_manager as factory

        assert callable(factory)
|
|
|
|
|
|
class TestAnalystsModuleExports:
    """Verify the analysts subpackage exposes the new analyst factories."""

    def test_analysts_module_exports_momentum(self):
        """analysts exposes a callable create_momentum_analyst."""
        from tradingagents.agents.analysts import create_momentum_analyst as factory

        assert callable(factory)

    def test_analysts_module_exports_macro(self):
        """analysts exposes a callable create_macro_analyst."""
        from tradingagents.agents.analysts import create_macro_analyst as factory

        assert callable(factory)

    def test_analysts_module_exports_correlation(self):
        """analysts exposes a callable create_correlation_analyst."""
        from tradingagents.agents.analysts import create_correlation_analyst as factory

        assert callable(factory)

    def test_analysts_module_all_exports(self):
        """analysts.__all__ lists every new analyst factory."""
        from tradingagents.agents import analysts

        public_names = analysts.__all__
        for expected in (
            "create_momentum_analyst",
            "create_macro_analyst",
            "create_correlation_analyst",
        ):
            assert expected in public_names
|
|
|
|
|
|
class TestManagersModuleExports:
    """Verify the managers subpackage exposes the position sizing manager."""

    def test_managers_module_exports_position_sizing(self):
        """managers exposes a callable create_position_sizing_manager."""
        from tradingagents.agents.managers import (
            create_position_sizing_manager as factory,
        )

        assert callable(factory)

    def test_managers_module_all_exports(self):
        """managers.__all__ lists create_position_sizing_manager."""
        from tradingagents.agents import managers

        public_names = managers.__all__
        assert "create_position_sizing_manager" in public_names
|
|
|
|
|
|
class TestToolImports:
    """Verify each new analyst module exposes its tool functions."""

    def test_momentum_tools_importable(self):
        """The momentum analyst module provides its three tools."""
        from tradingagents.agents.analysts.momentum_analyst import (
            get_multi_timeframe_momentum,
            get_adx_analysis,
            get_momentum_divergence,
        )

        for tool_fn in (
            get_multi_timeframe_momentum,
            get_adx_analysis,
            get_momentum_divergence,
        ):
            assert callable(tool_fn)

    def test_macro_tools_importable(self):
        """The macro analyst module provides its four tools."""
        from tradingagents.agents.analysts.macro_analyst import (
            get_economic_regime_analysis,
            get_yield_curve_analysis,
            get_monetary_policy_analysis,
            get_inflation_regime_analysis,
        )

        for tool_fn in (
            get_economic_regime_analysis,
            get_yield_curve_analysis,
            get_monetary_policy_analysis,
            get_inflation_regime_analysis,
        ):
            assert callable(tool_fn)

    def test_correlation_tools_importable(self):
        """The correlation analyst module provides its four tools."""
        from tradingagents.agents.analysts.correlation_analyst import (
            get_cross_asset_correlation_analysis,
            get_sector_rotation_analysis,
            get_correlation_matrix,
            get_rolling_correlation_trend,
        )

        for tool_fn in (
            get_cross_asset_correlation_analysis,
            get_sector_rotation_analysis,
            get_correlation_matrix,
            get_rolling_correlation_trend,
        ):
            assert callable(tool_fn)
|
|
|
|
|
|
class TestTradingGraphToolNodes:
    """Test that trading_graph.py has tool nodes for new analysts."""

    @staticmethod
    def _tool_nodes():
        """Return ``_create_tool_nodes()`` from a TradingAgentsGraph built without ``__init__``.

        ``TradingAgentsGraph.__new__(TradingAgentsGraph)`` allocates the object
        without calling ``__init__``, so no ``__init__`` patch is needed here.
        """
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
        return graph._create_tool_nodes()

    # The ChatOpenAI / set_config / os.makedirs patches are kept as a safety net
    # in case _create_tool_nodes touches them.
    # NOTE(review): likely unnecessary now that __init__ is never run; confirm.
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.set_config")
    @patch("os.makedirs")
    def test_tool_nodes_include_momentum(self, mock_makedirs, mock_set_config, mock_llm):
        """_create_tool_nodes should include momentum tools."""
        assert "momentum" in self._tool_nodes()

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.set_config")
    @patch("os.makedirs")
    def test_tool_nodes_include_macro(self, mock_makedirs, mock_set_config, mock_llm):
        """_create_tool_nodes should include macro tools."""
        assert "macro" in self._tool_nodes()

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.set_config")
    @patch("os.makedirs")
    def test_tool_nodes_include_correlation(self, mock_makedirs, mock_set_config, mock_llm):
        """_create_tool_nodes should include correlation tools."""
        assert "correlation" in self._tool_nodes()
|
|
|
|
|
|
class TestGraphSetupDocstring:
    """Test that setup_graph docstring documents new analysts."""

    @staticmethod
    def _setup_graph_doc():
        """Return ``GraphSetup.setup_graph.__doc__`` lower-cased.

        If the docstring is missing (``None``), this fails with a clear
        assertion message. Otherwise ``None.lower()`` would surface as an
        unrelated ``AttributeError``.
        """
        from tradingagents.graph.setup import GraphSetup

        docstring = GraphSetup.setup_graph.__doc__
        assert docstring, "GraphSetup.setup_graph has no docstring"
        return docstring.lower()

    def test_docstring_mentions_momentum(self):
        """setup_graph docstring should mention momentum analyst."""
        assert "momentum" in self._setup_graph_doc()

    def test_docstring_mentions_macro(self):
        """setup_graph docstring should mention macro analyst."""
        assert "macro" in self._setup_graph_doc()

    def test_docstring_mentions_correlation(self):
        """setup_graph docstring should mention correlation analyst."""
        assert "correlation" in self._setup_graph_doc()
|
|
|
|
|
|
class TestLogStateNewReports:
    """Test that _log_state includes new report fields."""

    # Report keys introduced by the momentum, macro and correlation analysts.
    _NEW_REPORT_KEYS = ("momentum_report", "macro_report", "correlation_report")

    @staticmethod
    def _make_graph():
        """Return a TradingAgentsGraph created without running ``__init__``.

        ``__new__`` does not invoke ``__init__``, so no ``__init__`` patch is
        needed. Only ``log_states_dict`` and ``ticker`` are set by hand.
        """
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
        graph.log_states_dict = {}
        graph.ticker = "TEST"
        return graph

    @staticmethod
    def _make_final_state(**overrides):
        """Build a complete final_state for ticker TEST on 2024-01-01.

        Every report and plan field defaults to ``""``. Keyword arguments
        replace individual top-level values, e.g. ``momentum_report="..."``.
        """
        state = {
            "company_of_interest": "TEST",
            "trade_date": "2024-01-01",
            "market_report": "",
            "sentiment_report": "",
            "news_report": "",
            "fundamentals_report": "",
            "momentum_report": "",
            "macro_report": "",
            "correlation_report": "",
            "investment_debate_state": {
                "bull_history": "",
                "bear_history": "",
                "history": "",
                "current_response": "",
                "judge_decision": "",
            },
            "trader_investment_plan": "",
            "risk_debate_state": {
                "risky_history": "",
                "safe_history": "",
                "neutral_history": "",
                "history": "",
                "judge_decision": "",
            },
            "investment_plan": "",
            "final_trade_decision": "",
        }
        state.update(overrides)
        return state

    @staticmethod
    def _log(graph, final_state):
        """Run ``graph._log_state`` for 2024-01-01 and return the logged entry.

        ``open``, ``json.dump`` and ``Path.mkdir`` are patched so nothing is
        written to disk.
        """
        with patch("builtins.open", MagicMock()):
            with patch("json.dump"):
                with patch("pathlib.Path.mkdir"):
                    graph._log_state("2024-01-01", final_state)
        return graph.log_states_dict["2024-01-01"]

    def test_log_state_includes_momentum_report(self):
        """_log_state should log momentum_report if present."""
        graph = self._make_graph()
        final_state = self._make_final_state(
            momentum_report="Momentum analysis result"
        )

        logged = self._log(graph, final_state)

        assert "momentum_report" in logged
        assert logged["momentum_report"] == "Momentum analysis result"

    def test_log_state_includes_macro_report(self):
        """_log_state should log macro_report if present."""
        graph = self._make_graph()
        final_state = self._make_final_state(macro_report="Macro analysis result")

        logged = self._log(graph, final_state)

        assert "macro_report" in logged
        assert logged["macro_report"] == "Macro analysis result"

    def test_log_state_includes_correlation_report(self):
        """_log_state should log correlation_report if present."""
        graph = self._make_graph()
        final_state = self._make_final_state(
            correlation_report="Correlation analysis result"
        )

        logged = self._log(graph, final_state)

        assert "correlation_report" in logged
        assert logged["correlation_report"] == "Correlation analysis result"

    def test_log_state_handles_missing_new_reports(self):
        """_log_state should handle missing new report fields gracefully."""
        graph = self._make_graph()

        # Backward compatibility: a state from before the new analysts existed
        # would still carry the legacy market/sentiment/news/fundamentals
        # reports. Remove only the new keys so the test isolates the
        # new-field defaulting path.
        final_state = self._make_final_state()
        for key in self._NEW_REPORT_KEYS:
            del final_state[key]

        logged = self._log(graph, final_state)

        # Should default to empty string
        for key in self._NEW_REPORT_KEYS:
            assert logged[key] == ""
|