TradingAgents/tests/unit/graph/test_trading_graph.py

471 lines
17 KiB
Python

"""Unit tests for TradingAgentsGraph."""
from unittest.mock import Mock, mock_open, patch
import pytest
from tradingagents.graph.trading_graph import TradingAgentsGraph
from .mock_toolkit_fix import patch_toolkit_in_test
class TestTradingAgentsGraph:
"""Test suite for TradingAgentsGraph class."""
@patch("tradingagents.dataflows.config.set_config")
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_basic(
self,
mock_toolkit,
mock_chat_openai,
mock_set_config,
sample_config,
temp_data_dir,
):
"""Test basic initialization of TradingAgentsGraph."""
# Setup
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = patch_toolkit_in_test(mock_toolkit)
mock_toolkit_instance.config = sample_config
# Execute
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
graph = TradingAgentsGraph(config=sample_config)
# Verify
assert graph.config == sample_config
assert graph.debug is False
mock_set_config.assert_called_once_with(sample_config)
assert (
mock_chat_openai.call_count == 2
) # deep_thinking_llm and quick_thinking_llm
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_debug(
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test initialization with debug mode enabled."""
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
graph = TradingAgentsGraph(debug=True, config=sample_config)
assert graph.debug is True
@patch("tradingagents.graph.trading_graph.ChatAnthropic")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_anthropic(
self,
mock_toolkit,
mock_chat_anthropic,
sample_config,
temp_data_dir,
):
"""Test initialization with Anthropic LLM provider."""
sample_config["project_dir"] = temp_data_dir
sample_config["llm_provider"] = "anthropic"
mock_llm = Mock()
mock_chat_anthropic.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
TradingAgentsGraph(config=sample_config)
assert mock_chat_anthropic.call_count == 2
@patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_with_google(
self,
mock_toolkit,
mock_chat_google,
sample_config,
temp_data_dir,
):
"""Test initialization with Google LLM provider."""
sample_config["project_dir"] = temp_data_dir
sample_config["llm_provider"] = "google"
mock_llm = Mock()
mock_chat_google.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
TradingAgentsGraph(config=sample_config)
assert mock_chat_google.call_count == 2
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_init_unsupported_llm_provider(
self,
mock_toolkit,
sample_config,
temp_data_dir,
):
"""Test initialization with unsupported LLM provider raises error."""
sample_config["project_dir"] = temp_data_dir
sample_config["llm_provider"] = "unsupported"
patch_toolkit_in_test(mock_toolkit)
with pytest.raises(ValueError, match="Unsupported LLM provider"):
with patch("tradingagents.dataflows.config.set_config"):
TradingAgentsGraph(config=sample_config)
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_create_tool_nodes(
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test creation of tool nodes."""
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock()
# Setup toolkit methods
mock_toolkit_instance.get_YFin_data_online = Mock()
mock_toolkit_instance.get_YFin_data = Mock()
mock_toolkit_instance.get_stockstats_indicators_report_online = Mock()
mock_toolkit_instance.get_stockstats_indicators_report = Mock()
patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
graph = TradingAgentsGraph(config=sample_config)
# Verify tool nodes are created
assert "market" in graph.tool_nodes
assert "social" in graph.tool_nodes
assert "news" in graph.tool_nodes
assert "fundamentals" in graph.tool_nodes
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_basic(
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test basic propagate functionality."""
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
# Mock the graph and its invoke method
mock_graph = Mock()
mock_final_state = {
"company_of_interest": "AAPL",
"trade_date": "2024-05-10",
"market_report": "",
"sentiment_report": "",
"news_report": "",
"fundamentals_report": "",
"investment_debate_state": {
"bull_history": [],
"bear_history": [],
"history": [],
"current_response": "",
"judge_decision": "",
},
"trader_investment_plan": "",
"risk_debate_state": {
"risky_history": [],
"safe_history": [],
"neutral_history": [],
"history": [],
"judge_decision": "",
},
"investment_plan": "",
"final_trade_decision": "HOLD",
}
mock_graph.invoke.return_value = mock_final_state
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
graph = TradingAgentsGraph(config=sample_config)
graph.graph = mock_graph
# Mock the propagator and signal processor
graph.propagator.create_initial_state = Mock(
return_value={"test": "state"},
)
graph.propagator.get_graph_args = Mock(return_value={})
graph.signal_processor.process_signal = Mock(return_value="HOLD")
# Execute
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = graph.propagate("AAPL", "2024-05-10")
# Verify
assert final_state == mock_final_state
assert decision == "HOLD"
assert graph.ticker == "AAPL"
assert graph.curr_state == mock_final_state
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_propagate_debug_mode(
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test propagate in debug mode."""
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
# Mock the graph stream method for debug mode
mock_graph = Mock()
mock_chunk = {"messages": [Mock()]}
mock_chunk["messages"][0].pretty_print = Mock()
mock_graph.stream.return_value = [mock_chunk, mock_chunk] # Multiple chunks
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
graph = TradingAgentsGraph(debug=True, config=sample_config)
graph.graph = mock_graph
# Mock other components
graph.propagator.create_initial_state = Mock(
return_value={"test": "state"},
)
graph.propagator.get_graph_args = Mock(return_value={})
graph.signal_processor.process_signal = Mock(return_value="BUY")
# Execute
with patch("builtins.open", create=True), patch("json.dump"):
final_state, decision = graph.propagate("TSLA", "2024-05-15")
# Verify debug mode was used
mock_graph.stream.assert_called_once()
assert graph.debug is True
assert decision == "BUY"
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_log_state(
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test state logging functionality."""
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
graph = TradingAgentsGraph(config=sample_config)
graph.ticker = "NVDA"
# Create a mock final state
final_state = {
"company_of_interest": "NVDA",
"trade_date": "2024-05-20",
"market_report": "Market looking good",
"sentiment_report": "Positive sentiment",
"news_report": "Good news",
"fundamentals_report": "Strong fundamentals",
"investment_debate_state": {
"bull_history": [],
"bear_history": [],
"history": [],
"current_response": "",
"judge_decision": "BUY",
},
"trader_investment_plan": "Buy 100 shares",
"risk_debate_state": {
"risky_history": [],
"safe_history": [],
"neutral_history": [],
"history": [],
"judge_decision": "LOW_RISK",
},
"investment_plan": "Execute buy order",
"final_trade_decision": "BUY",
}
# Mock file operations
with patch("pathlib.Path.mkdir"), patch("builtins.open", mock_open()):
with patch("json.dump"):
graph._log_state("2024-05-20", final_state)
# Verify logging occurred
assert "2024-05-20" in graph.log_states_dict
logged_state = graph.log_states_dict["2024-05-20"]
assert logged_state["company_of_interest"] == "NVDA"
assert logged_state["final_trade_decision"] == "BUY"
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_reflect_and_remember(
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test reflection and memory update functionality."""
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
with (
patch(
"tradingagents.graph.trading_graph.FinancialSituationMemory",
),
patch("tradingagents.graph.trading_graph.set_config"),
):
graph = TradingAgentsGraph(config=sample_config)
# Set up current state
graph.curr_state = {"test": "state"}
# Mock reflector methods
graph.reflector.reflect_bull_researcher = Mock()
graph.reflector.reflect_bear_researcher = Mock()
graph.reflector.reflect_trader = Mock()
graph.reflector.reflect_invest_judge = Mock()
graph.reflector.reflect_risk_manager = Mock()
returns_losses = {"return": 0.05, "loss": -0.02}
# Execute
graph.reflect_and_remember(returns_losses)
# Verify all reflection methods were called
graph.reflector.reflect_bull_researcher.assert_called_once()
graph.reflector.reflect_bear_researcher.assert_called_once()
graph.reflector.reflect_trader.assert_called_once()
graph.reflector.reflect_invest_judge.assert_called_once()
graph.reflector.reflect_risk_manager.assert_called_once()
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_process_signal(
self,
mock_toolkit,
mock_chat_openai,
sample_config,
temp_data_dir,
):
"""Test signal processing functionality."""
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
graph = TradingAgentsGraph(config=sample_config)
graph.signal_processor.process_signal = Mock(return_value="BUY")
full_signal = "Based on analysis, recommend BUY with confidence 0.8"
result = graph.process_signal(full_signal)
assert result == "BUY"
graph.signal_processor.process_signal.assert_called_once_with(full_signal)
@pytest.mark.parametrize(
"selected_analysts",
[
["market"],
["market", "social"],
["market", "social", "news"],
["market", "social", "news", "fundamentals"],
],
)
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_selected_analysts_configuration(
self,
mock_toolkit,
mock_chat_openai,
selected_analysts,
sample_config,
temp_data_dir,
):
"""Test different analyst configurations."""
sample_config["project_dir"] = temp_data_dir
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
TradingAgentsGraph(
selected_analysts=selected_analysts,
config=sample_config,
)
# Verify graph was set up with selected analysts
# (The actual setup_graph method would be mocked in a real implementation)
assert len(selected_analysts) >= 1 # Basic validation
class TestTradingAgentsGraphErrorHandling:
"""Test error handling in TradingAgentsGraph."""
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_invalid_config_handling(self, mock_toolkit):
"""Test handling of invalid configuration."""
invalid_config = {"invalid_key": "invalid_value"}
patch_toolkit_in_test(mock_toolkit)
# This should still work as the class should use defaults for missing keys
with patch("tradingagents.dataflows.config.set_config"):
with pytest.raises(
KeyError,
): # Should fail when trying to access missing config keys
TradingAgentsGraph(config=invalid_config)
@patch("tradingagents.graph.trading_graph.ChatOpenAI")
@patch("tradingagents.graph.trading_graph.Toolkit")
def test_directory_creation_failure(
self,
mock_toolkit,
mock_chat_openai,
sample_config,
):
"""Test handling when directory creation fails."""
sample_config["project_dir"] = "/invalid/path/that/cannot/be/created"
mock_llm = Mock()
mock_chat_openai.return_value = mock_llm
patch_toolkit_in_test(mock_toolkit)
# Should handle directory creation gracefully or raise appropriate error
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"):
# This might raise PermissionError or similar, depending on implementation
try:
TradingAgentsGraph(config=sample_config)
except (PermissionError, OSError):
# This is expected for invalid paths
pass