diff --git a/test_graph_fix.py b/test_graph_fix.py index da1a60ba..c7fe8a8d 100644 --- a/test_graph_fix.py +++ b/test_graph_fix.py @@ -14,7 +14,7 @@ from tests.unit.graph.mock_toolkit_fix import create_mock_toolkit_with_tools def test_mock_toolkit_has_all_methods(): """Test that the mock toolkit has all required methods.""" toolkit = create_mock_toolkit_with_tools() - + required_methods = [ "get_YFin_data", "get_YFin_data_online", @@ -23,14 +23,14 @@ def test_mock_toolkit_has_all_methods(): "get_reddit_stock_info", "get_stock_news_openai", ] - + for method_name in required_methods: assert hasattr(toolkit, method_name), f"Missing {method_name}" method = getattr(toolkit, method_name) - assert hasattr(method, '__name__'), f"{method_name} missing __name__" + assert hasattr(method, "__name__"), f"{method_name} missing __name__" assert method.__name__ == method_name, f"{method_name} has wrong __name__" assert callable(method), f"{method_name} is not callable" - + print("✓ Mock toolkit has all required methods with proper attributes") return True @@ -40,17 +40,19 @@ def test_tool_node_creation(): # Mock the ToolNode class with patch("langgraph.prebuilt.ToolNode") as MockToolNode: MockToolNode.return_value = Mock() - + toolkit = create_mock_toolkit_with_tools() - + # Simulate creating tool nodes like in TradingAgentsGraph from langgraph.prebuilt import ToolNode - - tool_node = ToolNode([ - toolkit.get_YFin_data, - toolkit.get_stockstats_indicators_report, - ]) - + + tool_node = ToolNode( + [ + toolkit.get_YFin_data, + toolkit.get_stockstats_indicators_report, + ] + ) + # Should not raise an error assert MockToolNode.called print("✓ ToolNode can be created with mocked toolkit methods") @@ -60,13 +62,13 @@ def test_tool_node_creation(): def test_tool_decorator(): """Test that @tool decorator works with mocked functions.""" toolkit = create_mock_toolkit_with_tools() - + # The @tool decorator expects __name__ attribute for attr_name in dir(toolkit): - if attr_name.startswith('get_'): + if attr_name.startswith("get_"): method = getattr(toolkit, attr_name) - assert hasattr(method, '__name__'), f"{attr_name} missing __name__" - + assert hasattr(method, "__name__"), f"{attr_name} missing __name__" + print("✓ All toolkit methods are compatible with @tool decorator") return True @@ -74,13 +76,13 @@ def test_tool_decorator(): if __name__ == "__main__": print("Testing mock toolkit fixes for TradingAgentsGraph...") print("-" * 50) - + tests = [ test_mock_toolkit_has_all_methods, test_tool_node_creation, test_tool_decorator, ] - + all_passed = True for test in tests: try: @@ -91,10 +93,11 @@ if __name__ == "__main__": all_passed = False print(f"✗ {test.__name__} raised exception: {e}") import traceback + traceback.print_exc() - + print("-" * 50) if all_passed: print("✅ All tests passed! TradingAgentsGraph mock fixes are working.") else: - print("❌ Some tests failed. Check the output above.") \ No newline at end of file + print("❌ Some tests failed. Check the output above.") diff --git a/tests/unit/graph/mock_toolkit_fix.py b/tests/unit/graph/mock_toolkit_fix.py index 9e5a3005..af756b4b 100644 --- a/tests/unit/graph/mock_toolkit_fix.py +++ b/tests/unit/graph/mock_toolkit_fix.py @@ -7,7 +7,7 @@ def create_mock_toolkit_with_tools(): """Create a mock toolkit with all necessary tool methods.""" toolkit = Mock() toolkit.config = {"online_tools": False} - + # List of all methods that need to be mocked tool_methods = [ # Market tools @@ -25,25 +25,25 @@ def create_mock_toolkit_with_tools(): "get_reddit_news", # Fundamentals tools "get_simfin_cashflow", - "get_simfin_income_stmt", + "get_simfin_income_stmt", "get_simfin_balance_sheet", "get_finnhub_basic_financials", ] - + # Create mock for each method with proper __name__ attribute for method_name in tool_methods: # Create a function with the right name def mock_func(): return f"Mock {method_name} data" - + # Create Mock wrapping the function mock_method = Mock(side_effect=mock_func) mock_method.__name__ = method_name mock_method.name = method_name - + # Set it on the toolkit setattr(toolkit, method_name, mock_method) - + return toolkit @@ -51,4 +51,4 @@ def patch_toolkit_in_test(mock_toolkit): """Configure the mock_toolkit patch to return a properly mocked instance.""" mock_instance = create_mock_toolkit_with_tools() mock_toolkit.return_value = mock_instance - return mock_instance \ No newline at end of file + return mock_instance