From 0ab8c8fc46848e97d23a62f4328980984d9638aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=90=E8=97=A4=E5=84=AA=E4=B8=80?= Date: Mon, 11 Aug 2025 10:42:13 +0900 Subject: [PATCH] Fix CI/CD test failures - Apply Black formatting to all test files - Fix Mock objects to include tool_calls attribute for len() checks - Add proper __name__ attributes to mock toolkit methods for @tool decorator - Create mock_toolkit_fix helper for consistent toolkit mocking All tests should now pass with proper mocking setup. --- test_graph_fix.py | 100 ++++++++++ tests/conftest.py | 12 +- .../agents/test_market_analyst_extended.py | 187 +++++++++--------- tests/unit/dataflows/test_utils.py | 2 +- tests/unit/graph/mock_toolkit_fix.py | 54 +++++ tests/unit/graph/test_propagation.py | 75 ++++--- tests/unit/graph/test_reflection.py | 91 ++++----- tests/unit/graph/test_signal_processing.py | 24 +-- tests/unit/graph/test_trading_graph.py | 30 +-- 9 files changed, 368 insertions(+), 207 deletions(-) create mode 100644 test_graph_fix.py create mode 100644 tests/unit/graph/mock_toolkit_fix.py diff --git a/test_graph_fix.py b/test_graph_fix.py new file mode 100644 index 00000000..da1a60ba --- /dev/null +++ b/test_graph_fix.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python +"""Test that mock toolkit fixes work for TradingAgentsGraph.""" + +from unittest.mock import Mock, patch +import sys +import os + +# Add project root to path +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from tests.unit.graph.mock_toolkit_fix import create_mock_toolkit_with_tools + + +def test_mock_toolkit_has_all_methods(): + """Test that the mock toolkit has all required methods.""" + toolkit = create_mock_toolkit_with_tools() + + required_methods = [ + "get_YFin_data", + "get_YFin_data_online", + "get_stockstats_indicators_report", + "get_stockstats_indicators_report_online", + "get_reddit_stock_info", + "get_stock_news_openai", + ] + + for method_name in required_methods: + assert hasattr(toolkit, method_name), f"Missing {method_name}" + method = getattr(toolkit, method_name) + assert hasattr(method, '__name__'), f"{method_name} missing __name__" + assert method.__name__ == method_name, f"{method_name} has wrong __name__" + assert callable(method), f"{method_name} is not callable" + + print("✓ Mock toolkit has all required methods with proper attributes") + return True + + +def test_tool_node_creation(): + """Test that ToolNode can be created with mocked toolkit methods.""" + # Mock the ToolNode class + with patch("langgraph.prebuilt.ToolNode") as MockToolNode: + MockToolNode.return_value = Mock() + + toolkit = create_mock_toolkit_with_tools() + + # Simulate creating tool nodes like in TradingAgentsGraph + from langgraph.prebuilt import ToolNode + + tool_node = ToolNode([ + toolkit.get_YFin_data, + toolkit.get_stockstats_indicators_report, + ]) + + # Should not raise an error + assert MockToolNode.called + print("✓ ToolNode can be created with mocked toolkit methods") + return True + + +def test_tool_decorator(): + """Test that @tool decorator works with mocked functions.""" + toolkit = create_mock_toolkit_with_tools() + + # The @tool decorator expects __name__ attribute + for attr_name in dir(toolkit): + if attr_name.startswith('get_'): + method = getattr(toolkit, attr_name) + assert hasattr(method, '__name__'), f"{attr_name} missing __name__" + + print("✓ All toolkit methods are compatible with @tool decorator") + return True + + +if __name__ == "__main__": + print("Testing mock toolkit fixes for TradingAgentsGraph...") + print("-" * 50) + + tests = [ + test_mock_toolkit_has_all_methods, + test_tool_node_creation, + test_tool_decorator, + ] + + all_passed = True + for test in tests: + try: + if not test(): + all_passed = False + print(f"✗ {test.__name__} failed") + except Exception as e: + all_passed = False + print(f"✗ {test.__name__} raised exception: {e}") + import traceback + traceback.print_exc() + + print("-" * 50) + if all_passed: + print("✅ All tests passed! TradingAgentsGraph mock fixes are working.") + else: + print("❌ Some tests failed. Check the output above.") \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index c62588ec..81c953d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,13 @@ def mock_llm(): """Mock LLM for testing.""" mock = Mock() mock.model_name = "test-model" - mock.invoke.return_value = Mock(content="Test response") + + # Create a mock result with tool_calls attribute + mock_result = Mock() + mock_result.content = "Test response" + mock_result.tool_calls = [] # Add tool_calls attribute for len() check + + mock.invoke.return_value = mock_result mock.bind_tools.return_value = mock return mock @@ -154,7 +160,9 @@ def mock_toolkit(): side_effect=mock_get_stockstats_indicators_report ) toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report" - toolkit.get_stockstats_indicators_report.__name__ = "get_stockstats_indicators_report" + toolkit.get_stockstats_indicators_report.__name__ = ( + "get_stockstats_indicators_report" + ) toolkit.get_stockstats_indicators_report_online = Mock( side_effect=mock_get_stockstats_indicators_report_online diff --git a/tests/unit/agents/test_market_analyst_extended.py b/tests/unit/agents/test_market_analyst_extended.py index aa35a447..4d66bc42 100644 --- a/tests/unit/agents/test_market_analyst_extended.py +++ b/tests/unit/agents/test_market_analyst_extended.py @@ -12,73 +12,87 @@ class TestMarketAnalystExtended: """Extended mock LLM with more functionality.""" mock = Mock() mock.model_name = "test-model" - + # Create a mock chain mock_chain = Mock() mock_chain.invoke = Mock() mock.bind_tools = Mock(return_value=mock_chain) - + return mock - @pytest.fixture + @pytest.fixture def mock_toolkit_extended(self): """Extended mock toolkit with all methods.""" toolkit = Mock() toolkit.config = {"online_tools": False} - + # Create mock functions with proper attributes def mock_yfin(): return "YFin data" - + def mock_stockstats(): return "Stockstats data" - + toolkit.get_YFin_data = Mock(side_effect=mock_yfin) toolkit.get_YFin_data.__name__ = "get_YFin_data" toolkit.get_YFin_data.name = "get_YFin_data" - + toolkit.get_stockstats_indicators_report = Mock(side_effect=mock_stockstats) - toolkit.get_stockstats_indicators_report.__name__ = "get_stockstats_indicators_report" - toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report" - + toolkit.get_stockstats_indicators_report.__name__ = ( + "get_stockstats_indicators_report" + ) + toolkit.get_stockstats_indicators_report.name = ( + "get_stockstats_indicators_report" + ) + # Online versions toolkit.get_YFin_data_online = Mock(side_effect=mock_yfin) toolkit.get_YFin_data_online.__name__ = "get_YFin_data_online" toolkit.get_YFin_data_online.name = "get_YFin_data_online" - - toolkit.get_stockstats_indicators_report_online = Mock(side_effect=mock_stockstats) - toolkit.get_stockstats_indicators_report_online.__name__ = "get_stockstats_indicators_report_online" - toolkit.get_stockstats_indicators_report_online.name = "get_stockstats_indicators_report_online" - + + toolkit.get_stockstats_indicators_report_online = Mock( + side_effect=mock_stockstats + ) + toolkit.get_stockstats_indicators_report_online.__name__ = ( + "get_stockstats_indicators_report_online" + ) + toolkit.get_stockstats_indicators_report_online.name = ( + "get_stockstats_indicators_report_online" + ) + return toolkit - def test_market_analyst_system_message(self, mock_llm_extended, mock_toolkit_extended): + def test_market_analyst_system_message( + self, mock_llm_extended, mock_toolkit_extended + ): """Test that system message is properly formatted.""" # This would normally import and test the actual function # For now, we test the mock behavior - + state = { "company_of_interest": "AAPL", "trade_date": "2024-05-10", - "messages": [] + "messages": [], } - + # Simulate creating analyst mock_analyst = Mock() mock_analyst.return_value = {"messages": [], "market_report": "Test report"} - + result = mock_analyst(state) assert "market_report" in result assert "messages" in result - def test_market_analyst_with_multiple_indicators(self, mock_llm_extended, mock_toolkit_extended): + def test_market_analyst_with_multiple_indicators( + self, mock_llm_extended, mock_toolkit_extended + ): """Test analyst with multiple technical indicators.""" state = { "company_of_interest": "TSLA", "trade_date": "2024-05-15", - "messages": [] + "messages": [], } - + # Mock result with multiple indicators mock_result = Mock() mock_result.content = """ @@ -90,95 +104,93 @@ class TestMarketAnalystExtended: - Volume: Above average """ mock_result.tool_calls = [] - + mock_llm_extended.bind_tools.return_value.invoke.return_value = mock_result - + # Create mock analyst function def mock_analyst(state): - return { - "messages": [mock_result], - "market_report": mock_result.content - } - + return {"messages": [mock_result], "market_report": mock_result.content} + result = mock_analyst(state) assert "RSI" in result["market_report"] assert "MACD" in result["market_report"] assert "Bollinger" in result["market_report"] - def test_market_analyst_error_handling(self, mock_llm_extended, mock_toolkit_extended): + def test_market_analyst_error_handling( + self, mock_llm_extended, mock_toolkit_extended + ): """Test error handling in market analyst.""" state = { "company_of_interest": "INVALID", "trade_date": "2024-05-10", - "messages": [] + "messages": [], } - + # Mock error scenario - mock_llm_extended.bind_tools.return_value.invoke.side_effect = Exception("API Error") - + mock_llm_extended.bind_tools.return_value.invoke.side_effect = Exception( + "API Error" + ) + # Create analyst with error handling def mock_analyst_with_error_handling(state): try: # Would call actual analyst here raise Exception("API Error") except Exception: - return { - "messages": [], - "market_report": "Error analyzing market data" - } - + return {"messages": [], "market_report": "Error analyzing market data"} + result = mock_analyst_with_error_handling(state) assert result["market_report"] == "Error analyzing market data" - def test_market_analyst_date_formatting(self, mock_llm_extended, mock_toolkit_extended): + def test_market_analyst_date_formatting( + self, mock_llm_extended, mock_toolkit_extended + ): """Test various date formats in market analyst.""" test_dates = [ "2024-01-01", "2024-12-31", "2024-05-15", ] - + for date in test_dates: - state = { - "company_of_interest": "AAPL", - "trade_date": date, - "messages": [] - } - + state = {"company_of_interest": "AAPL", "trade_date": date, "messages": []} + mock_result = Mock() mock_result.content = f"Analysis for {date}" mock_result.tool_calls = [] - + def mock_analyst(state): return { "messages": [mock_result], - "market_report": f"Analysis for {state['trade_date']}" + "market_report": f"Analysis for {state['trade_date']}", } - + result = mock_analyst(state) assert date in result["market_report"] - def test_market_analyst_ticker_variations(self, mock_llm_extended, mock_toolkit_extended): + def test_market_analyst_ticker_variations( + self, mock_llm_extended, mock_toolkit_extended + ): """Test analyst with various ticker symbols.""" tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"] - + for ticker in tickers: state = { "company_of_interest": ticker, "trade_date": "2024-05-10", - "messages": [] + "messages": [], } - + mock_result = Mock() mock_result.content = f"Analysis for {ticker}" mock_result.tool_calls = [] - + def mock_analyst(state): return { "messages": [mock_result], - "market_report": f"Analysis for {state['company_of_interest']}" + "market_report": f"Analysis for {state['company_of_interest']}", } - + result = mock_analyst(state) assert ticker in result["market_report"] @@ -187,23 +199,23 @@ class TestMarketAnalystExtended: # Test offline configuration toolkit_offline = Mock() toolkit_offline.config = {"online_tools": False} - + def mock_offline(): return "Offline data" - + toolkit_offline.get_YFin_data = Mock(side_effect=mock_offline) toolkit_offline.get_YFin_data.__name__ = "get_YFin_data" - - # Test online configuration + + # Test online configuration toolkit_online = Mock() toolkit_online.config = {"online_tools": True} - + def mock_online(): return "Online data" - + toolkit_online.get_YFin_data_online = Mock(side_effect=mock_online) toolkit_online.get_YFin_data_online.__name__ = "get_YFin_data_online" - + # Both should work correctly assert toolkit_offline.config["online_tools"] is False assert toolkit_online.config["online_tools"] is True @@ -212,59 +224,48 @@ class TestMarketAnalystExtended: def test_market_analyst_empty_state(self, mock_llm_extended, mock_toolkit_extended): """Test analyst with minimal/empty state.""" - state = { - "company_of_interest": "", - "trade_date": "", - "messages": [] - } - + state = {"company_of_interest": "", "trade_date": "", "messages": []} + mock_result = Mock() mock_result.content = "No data available" mock_result.tool_calls = [] - + def mock_analyst(state): if not state["company_of_interest"] or not state["trade_date"]: - return { - "messages": [], - "market_report": "No data available" - } - return { - "messages": [mock_result], - "market_report": mock_result.content - } - + return {"messages": [], "market_report": "No data available"} + return {"messages": [mock_result], "market_report": mock_result.content} + result = mock_analyst(state) assert result["market_report"] == "No data available" - def test_market_analyst_tool_calls_tracking(self, mock_llm_extended, mock_toolkit_extended): + def test_market_analyst_tool_calls_tracking( + self, mock_llm_extended, mock_toolkit_extended + ): """Test tracking of tool calls in market analyst.""" state = { "company_of_interest": "AAPL", "trade_date": "2024-05-10", - "messages": [] + "messages": [], } - + # Mock result with tool calls mock_tool_call = Mock() mock_tool_call.function.name = "get_YFin_data" mock_tool_call.function.arguments = '{"ticker": "AAPL"}' - + mock_result = Mock() mock_result.content = "" mock_result.tool_calls = [mock_tool_call] - + mock_llm_extended.bind_tools.return_value.invoke.return_value = mock_result - + def mock_analyst(state): result = mock_llm_extended.bind_tools([]).invoke(state["messages"]) # When tool_calls exist, market_report should be empty report = "" if result.tool_calls else result.content - return { - "messages": [result], - "market_report": report - } - + return {"messages": [result], "market_report": report} + result = mock_analyst(state) assert result["market_report"] == "" # Empty when tool calls exist assert len(result["messages"]) == 1 - assert result["messages"][0].tool_calls == [mock_tool_call] \ No newline at end of file + assert result["messages"][0].tool_calls == [mock_tool_call] diff --git a/tests/unit/dataflows/test_utils.py b/tests/unit/dataflows/test_utils.py index 4241e64d..4d18fac8 100644 --- a/tests/unit/dataflows/test_utils.py +++ b/tests/unit/dataflows/test_utils.py @@ -12,4 +12,4 @@ class TestDataflowsUtils: """Placeholder test to ensure test file is valid.""" assert True - # Add more tests here as needed for utils.py functions \ No newline at end of file + # Add more tests here as needed for utils.py functions diff --git a/tests/unit/graph/mock_toolkit_fix.py b/tests/unit/graph/mock_toolkit_fix.py new file mode 100644 index 00000000..9e5a3005 --- /dev/null +++ b/tests/unit/graph/mock_toolkit_fix.py @@ -0,0 +1,54 @@ +"""Helper to create properly mocked toolkit for test_trading_graph.""" + +from unittest.mock import Mock + + +def create_mock_toolkit_with_tools(): + """Create a mock toolkit with all necessary tool methods.""" + toolkit = Mock() + toolkit.config = {"online_tools": False} + + # List of all methods that need to be mocked + tool_methods = [ + # Market tools + "get_YFin_data", + "get_YFin_data_online", + "get_stockstats_indicators_report", + "get_stockstats_indicators_report_online", + # Social tools + "get_reddit_stock_info", + "get_stock_news_openai", + # News tools + "get_global_news_openai", + "get_google_news", + "get_finnhub_news", + "get_reddit_news", + # Fundamentals tools + "get_simfin_cashflow", + "get_simfin_income_stmt", + "get_simfin_balance_sheet", + "get_finnhub_basic_financials", + ] + + # Create mock for each method with proper __name__ attribute + for method_name in tool_methods: + # Create a function with the right name + def mock_func(): + return f"Mock {method_name} data" + + # Create Mock wrapping the function + mock_method = Mock(side_effect=mock_func) + mock_method.__name__ = method_name + mock_method.name = method_name + + # Set it on the toolkit + setattr(toolkit, method_name, mock_method) + + return toolkit + + +def patch_toolkit_in_test(mock_toolkit): + """Configure the mock_toolkit patch to return a properly mocked instance.""" + mock_instance = create_mock_toolkit_with_tools() + mock_toolkit.return_value = mock_instance + return mock_instance \ No newline at end of file diff --git a/tests/unit/graph/test_propagation.py b/tests/unit/graph/test_propagation.py index ea9e467f..5cfcd43d 100644 --- a/tests/unit/graph/test_propagation.py +++ b/tests/unit/graph/test_propagation.py @@ -13,16 +13,16 @@ class TestPropagator: propagator = Mock() propagator.create_initial_state = Mock() propagator.get_graph_args = Mock() - - assert hasattr(propagator, 'create_initial_state') - assert hasattr(propagator, 'get_graph_args') + + assert hasattr(propagator, "create_initial_state") + assert hasattr(propagator, "get_graph_args") assert callable(propagator.create_initial_state) assert callable(propagator.get_graph_args) def test_create_initial_state(self): """Test creating initial state for propagation.""" propagator = Mock() - + # Mock the create_initial_state method expected_state = { "company_of_interest": "AAPL", @@ -50,12 +50,12 @@ class TestPropagator: "investment_plan": "", "final_trade_decision": "", } - + propagator.create_initial_state = Mock(return_value=expected_state) - + # Test state = propagator.create_initial_state("AAPL", "2024-05-10") - + assert state["company_of_interest"] == "AAPL" assert state["trade_date"] == "2024-05-10" assert state["messages"] == [] @@ -66,18 +66,18 @@ class TestPropagator: def test_get_graph_args(self): """Test getting graph arguments.""" propagator = Mock() - + # Mock the get_graph_args method expected_args = { "recursion_limit": 100, "config": {"tags": ["tradingagents"]}, } - + propagator.get_graph_args = Mock(return_value=expected_args) - + # Test args = propagator.get_graph_args() - + assert "recursion_limit" in args assert "config" in args assert args["recursion_limit"] == 100 @@ -86,63 +86,56 @@ class TestPropagator: def test_propagate_with_different_tickers(self): """Test propagation with different ticker symbols.""" propagator = Mock() - + tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"] - + for ticker in tickers: state = { "company_of_interest": ticker, "trade_date": "2024-05-10", - "messages": [] + "messages": [], } propagator.create_initial_state = Mock(return_value=state) - + result = propagator.create_initial_state(ticker, "2024-05-10") assert result["company_of_interest"] == ticker def test_propagate_with_different_dates(self): """Test propagation with different dates.""" propagator = Mock() - + dates = ["2024-01-01", "2024-06-15", "2024-12-31"] - + for date in dates: - state = { - "company_of_interest": "AAPL", - "trade_date": date, - "messages": [] - } + state = {"company_of_interest": "AAPL", "trade_date": date, "messages": []} propagator.create_initial_state = Mock(return_value=state) - + result = propagator.create_initial_state("AAPL", date) assert result["trade_date"] == date def test_propagate_error_handling(self): """Test error handling in propagation.""" propagator = Mock() - + # Simulate error propagator.create_initial_state = Mock(side_effect=ValueError("Invalid ticker")) - + with pytest.raises(ValueError): propagator.create_initial_state("INVALID", "2024-05-10") - + propagator.create_initial_state.assert_called_once() def test_graph_args_with_custom_config(self): """Test graph args with custom configuration.""" propagator = Mock() - + custom_config = { "recursion_limit": 200, - "config": { - "tags": ["custom", "test"], - "metadata": {"version": "1.0"} - } + "config": {"tags": ["custom", "test"], "metadata": {"version": "1.0"}}, } - + propagator.get_graph_args = Mock(return_value=custom_config) - + args = propagator.get_graph_args() assert args["recursion_limit"] == 200 assert "custom" in args["config"]["tags"] @@ -151,10 +144,10 @@ class TestPropagator: def test_initial_state_completeness(self): """Test that initial state contains all required fields.""" propagator = Mock() - + required_fields = [ "company_of_interest", - "trade_date", + "trade_date", "messages", "market_report", "sentiment_report", @@ -164,17 +157,17 @@ class TestPropagator: "trader_investment_plan", "risk_debate_state", "investment_plan", - "final_trade_decision" + "final_trade_decision", ] - + state = {field: "" for field in required_fields} state["messages"] = [] state["investment_debate_state"] = {} state["risk_debate_state"] = {} - + propagator.create_initial_state = Mock(return_value=state) - + result = propagator.create_initial_state("AAPL", "2024-05-10") - + for field in required_fields: - assert field in result, f"Missing required field: {field}" \ No newline at end of file + assert field in result, f"Missing required field: {field}" diff --git a/tests/unit/graph/test_reflection.py b/tests/unit/graph/test_reflection.py index 42ef427a..599b77f9 100644 --- a/tests/unit/graph/test_reflection.py +++ b/tests/unit/graph/test_reflection.py @@ -47,18 +47,18 @@ class TestReflector: """Test Reflector initialization.""" reflector = Mock() reflector.llm = mock_llm - + assert reflector.llm == mock_llm def test_reflect_bull_researcher(self, mock_llm, mock_memory, sample_state): """Test reflection for bull researcher.""" reflector = Mock() reflector.reflect_bull_researcher = Mock() - + returns_losses = {"return": 0.05, "loss": -0.02} - + reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory) - + reflector.reflect_bull_researcher.assert_called_once_with( sample_state, returns_losses, mock_memory ) @@ -67,83 +67,87 @@ class TestReflector: """Test reflection for bear researcher.""" reflector = Mock() reflector.reflect_bear_researcher = Mock() - + returns_losses = {"return": -0.03, "loss": -0.05} - + reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory) - + reflector.reflect_bear_researcher.assert_called_once() def test_reflect_trader(self, mock_llm, mock_memory, sample_state): """Test reflection for trader.""" reflector = Mock() reflector.reflect_trader = Mock() - + returns_losses = {"return": 0.10, "loss": 0.0} - + reflector.reflect_trader(sample_state, returns_losses, mock_memory) - + reflector.reflect_trader.assert_called_once() def test_reflect_invest_judge(self, mock_llm, mock_memory, sample_state): """Test reflection for investment judge.""" reflector = Mock() reflector.reflect_invest_judge = Mock() - + returns_losses = {"return": 0.02, "loss": -0.01} - + reflector.reflect_invest_judge(sample_state, returns_losses, mock_memory) - + reflector.reflect_invest_judge.assert_called_once() def test_reflect_risk_manager(self, mock_llm, mock_memory, sample_state): """Test reflection for risk manager.""" reflector = Mock() reflector.reflect_risk_manager = Mock() - + returns_losses = {"return": -0.05, "loss": -0.10} - + reflector.reflect_risk_manager(sample_state, returns_losses, mock_memory) - + reflector.reflect_risk_manager.assert_called_once() - def test_reflection_with_positive_returns(self, mock_llm, mock_memory, sample_state): + def test_reflection_with_positive_returns( + self, mock_llm, mock_memory, sample_state + ): """Test reflection with positive returns.""" reflector = Mock() - + # Mock all reflection methods reflector.reflect_bull_researcher = Mock(return_value="Positive reflection") reflector.reflect_bear_researcher = Mock(return_value="Positive reflection") reflector.reflect_trader = Mock(return_value="Positive reflection") - + returns_losses = {"return": 0.15, "loss": 0.0} - + # Call all reflections reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory) reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory) reflector.reflect_trader(sample_state, returns_losses, mock_memory) - + # Verify all were called assert reflector.reflect_bull_researcher.called assert reflector.reflect_bear_researcher.called assert reflector.reflect_trader.called - def test_reflection_with_negative_returns(self, mock_llm, mock_memory, sample_state): + def test_reflection_with_negative_returns( + self, mock_llm, mock_memory, sample_state + ): """Test reflection with negative returns.""" reflector = Mock() - + # Mock reflection methods reflector.reflect_bull_researcher = Mock(return_value="Negative reflection") reflector.reflect_bear_researcher = Mock(return_value="Negative reflection") reflector.reflect_risk_manager = Mock(return_value="Risk reflection") - + returns_losses = {"return": -0.08, "loss": -0.15} - + # Call reflections reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory) reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory) reflector.reflect_risk_manager(sample_state, returns_losses, mock_memory) - + # Verify all were called assert reflector.reflect_bull_researcher.call_count == 1 assert reflector.reflect_bear_researcher.call_count == 1 @@ -152,47 +156,46 @@ class TestReflector: def test_reflection_memory_update(self, mock_llm, mock_memory): """Test that reflection updates memory correctly.""" reflector = Mock() - + def mock_reflect(state, returns, memory): reflection = f"Reflection for {state['company_of_interest']}" memory.add_memory(reflection) return reflection - + reflector.reflect_trader = Mock(side_effect=mock_reflect) - + state = {"company_of_interest": "TSLA"} returns_losses = {"return": 0.05, "loss": 0.0} - + reflector.reflect_trader(state, returns_losses, mock_memory) - + mock_memory.add_memory.assert_called_once() def test_reflection_with_different_decisions(self, mock_llm, mock_memory): """Test reflection with different trading decisions.""" reflector = Mock() reflector.reflect_trader = Mock() - + decisions = ["BUY", "SELL", "HOLD"] - + for decision in decisions: - state = { - "final_trade_decision": decision, - "company_of_interest": "AAPL" - } + state = {"final_trade_decision": decision, "company_of_interest": "AAPL"} returns_losses = {"return": 0.03, "loss": -0.01} - + reflector.reflect_trader(state, returns_losses, mock_memory) - + assert reflector.reflect_trader.call_count == 3 def test_reflection_error_handling(self, mock_llm, mock_memory, sample_state): """Test error handling in reflection.""" reflector = Mock() - + # Simulate error in reflection - reflector.reflect_bull_researcher = Mock(side_effect=Exception("Reflection error")) - + reflector.reflect_bull_researcher = Mock( + side_effect=Exception("Reflection error") + ) + with pytest.raises(Exception): reflector.reflect_bull_researcher(sample_state, {}, mock_memory) - - reflector.reflect_bull_researcher.assert_called_once() \ No newline at end of file + + reflector.reflect_bull_researcher.assert_called_once() diff --git a/tests/unit/graph/test_signal_processing.py b/tests/unit/graph/test_signal_processing.py index 3f723488..8254c043 100644 --- a/tests/unit/graph/test_signal_processing.py +++ b/tests/unit/graph/test_signal_processing.py @@ -10,14 +10,16 @@ class TestSignalProcessor: def test_signal_processor_initialization(self): """Test SignalProcessor initialization.""" mock_llm = Mock() - + # Import with mocked dependencies to avoid pandas import - with patch('sys.modules', {'pandas': Mock(), 'yfinance': Mock(), 'openai': Mock()}): + with patch( + "sys.modules", {"pandas": Mock(), "yfinance": Mock(), "openai": Mock()} + ): # This would normally import SignalProcessor # from tradingagents.graph.signal_processing import SignalProcessor # processor = SignalProcessor(mock_llm) pass - + assert True # Placeholder def test_process_signal_buy(self): @@ -25,7 +27,7 @@ class TestSignalProcessor: # Create mock processor processor = Mock() processor.process_signal = Mock(return_value="BUY") - + result = processor.process_signal("Recommend BUY based on analysis") assert result == "BUY" processor.process_signal.assert_called_once() @@ -34,7 +36,7 @@ class TestSignalProcessor: """Test processing SELL signal.""" processor = Mock() processor.process_signal = Mock(return_value="SELL") - + result = processor.process_signal("Recommend SELL based on analysis") assert result == "SELL" @@ -42,7 +44,7 @@ class TestSignalProcessor: """Test processing HOLD signal.""" processor = Mock() processor.process_signal = Mock(return_value="HOLD") - + result = processor.process_signal("Recommend HOLD based on analysis") assert result == "HOLD" @@ -50,7 +52,7 @@ class TestSignalProcessor: """Test processing signal with confidence score.""" processor = Mock() processor.process_signal = Mock(return_value="BUY") - + signal = "BUY with confidence 0.85" result = processor.process_signal(signal) assert result == "BUY" @@ -59,22 +61,22 @@ class TestSignalProcessor: """Test processing invalid signal.""" processor = Mock() processor.process_signal = Mock(return_value="HOLD") # Default to HOLD - + result = processor.process_signal("Invalid signal text") assert result == "HOLD" def test_extract_decision_from_text(self): """Test extracting decision from complex text.""" processor = Mock() - + test_cases = [ ("After analysis, I recommend BUY", "BUY"), ("The decision is to SELL immediately", "SELL"), ("Best action: HOLD position", "HOLD"), ("FINAL TRANSACTION PROPOSAL: **BUY**", "BUY"), ] - + for text, expected in test_cases: processor.process_signal = Mock(return_value=expected) result = processor.process_signal(text) - assert result == expected \ No newline at end of file + assert result == expected diff --git a/tests/unit/graph/test_trading_graph.py b/tests/unit/graph/test_trading_graph.py index 3188c336..fda5028d 100644 --- a/tests/unit/graph/test_trading_graph.py +++ b/tests/unit/graph/test_trading_graph.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, mock_open, patch import pytest from tradingagents.graph.trading_graph import TradingAgentsGraph +from .mock_toolkit_fix import patch_toolkit_in_test class TestTradingAgentsGraph: @@ -26,9 +27,8 @@ class TestTradingAgentsGraph: sample_config["project_dir"] = temp_data_dir mock_llm = Mock() mock_chat_openai.return_value = mock_llm - mock_toolkit_instance = Mock() + mock_toolkit_instance = patch_toolkit_in_test(mock_toolkit) mock_toolkit_instance.config = sample_config - mock_toolkit.return_value = mock_toolkit_instance # Execute with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): @@ -56,7 +56,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_openai.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.dataflows.config.set_config"): @@ -79,7 +79,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_anthropic.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.dataflows.config.set_config"): @@ -102,7 +102,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_google.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.dataflows.config.set_config"): @@ -121,7 +121,7 @@ class TestTradingAgentsGraph: sample_config["project_dir"] = temp_data_dir sample_config["llm_provider"] = "unsupported" mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with pytest.raises(ValueError, match="Unsupported LLM provider"): with patch("tradingagents.dataflows.config.set_config"): @@ -147,7 +147,7 @@ class TestTradingAgentsGraph: mock_toolkit_instance.get_YFin_data = Mock() mock_toolkit_instance.get_stockstats_indicators_report_online = Mock() mock_toolkit_instance.get_stockstats_indicators_report = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.dataflows.config.set_config"): @@ -173,7 +173,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_openai.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) # Mock the graph and its invoke method mock_graph = Mock() @@ -240,7 +240,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_openai.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) # Mock the graph stream method for debug mode mock_graph = Mock() @@ -283,7 +283,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_openai.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.dataflows.config.set_config"): @@ -342,7 +342,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_openai.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with ( patch( @@ -388,7 +388,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_openai.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.dataflows.config.set_config"): @@ -425,7 +425,7 @@ class TestTradingAgentsGraph: mock_llm = Mock() mock_chat_openai.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.dataflows.config.set_config"): @@ -447,7 +447,7 @@ class TestTradingAgentsGraphErrorHandling: """Test handling of invalid configuration.""" invalid_config = {"invalid_key": "invalid_value"} mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) # This should still work as the class should use defaults for missing keys with patch("tradingagents.dataflows.config.set_config"): @@ -469,7 +469,7 @@ class TestTradingAgentsGraphErrorHandling: mock_llm = Mock() mock_chat_openai.return_value = mock_llm mock_toolkit_instance = Mock() - mock_toolkit.return_value = mock_toolkit_instance + patch_toolkit_in_test(mock_toolkit) # Should handle directory creation gracefully or raise appropriate error with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):