104 lines
3.2 KiB
Python
104 lines
3.2 KiB
Python
#!/usr/bin/env python
|
|
"""Test that mock toolkit fixes work for TradingAgentsGraph."""
|
|
|
|
from unittest.mock import Mock, patch
|
|
import sys
|
|
import os
|
|
|
|
# Add project root to path
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from tests.unit.graph.mock_toolkit_fix import create_mock_toolkit_with_tools
|
|
|
|
|
|
def test_mock_toolkit_has_all_methods():
|
|
"""Test that the mock toolkit has all required methods."""
|
|
toolkit = create_mock_toolkit_with_tools()
|
|
|
|
required_methods = [
|
|
"get_YFin_data",
|
|
"get_YFin_data_online",
|
|
"get_stockstats_indicators_report",
|
|
"get_stockstats_indicators_report_online",
|
|
"get_reddit_stock_info",
|
|
"get_stock_news_openai",
|
|
]
|
|
|
|
for method_name in required_methods:
|
|
assert hasattr(toolkit, method_name), f"Missing {method_name}"
|
|
method = getattr(toolkit, method_name)
|
|
assert hasattr(method, "__name__"), f"{method_name} missing __name__"
|
|
assert method.__name__ == method_name, f"{method_name} has wrong __name__"
|
|
assert callable(method), f"{method_name} is not callable"
|
|
|
|
print("✓ Mock toolkit has all required methods with proper attributes")
|
|
return True
|
|
|
|
|
|
def test_tool_node_creation():
|
|
"""Test that ToolNode can be created with mocked toolkit methods."""
|
|
# Mock the ToolNode class
|
|
with patch("langgraph.prebuilt.ToolNode") as MockToolNode:
|
|
MockToolNode.return_value = Mock()
|
|
|
|
toolkit = create_mock_toolkit_with_tools()
|
|
|
|
# Simulate creating tool nodes like in TradingAgentsGraph
|
|
from langgraph.prebuilt import ToolNode
|
|
|
|
tool_node = ToolNode(
|
|
[
|
|
toolkit.get_YFin_data,
|
|
toolkit.get_stockstats_indicators_report,
|
|
]
|
|
)
|
|
|
|
# Should not raise an error
|
|
assert MockToolNode.called
|
|
print("✓ ToolNode can be created with mocked toolkit methods")
|
|
return True
|
|
|
|
|
|
def test_tool_decorator():
|
|
"""Test that @tool decorator works with mocked functions."""
|
|
toolkit = create_mock_toolkit_with_tools()
|
|
|
|
# The @tool decorator expects __name__ attribute
|
|
for attr_name in dir(toolkit):
|
|
if attr_name.startswith("get_"):
|
|
method = getattr(toolkit, attr_name)
|
|
assert hasattr(method, "__name__"), f"{attr_name} missing __name__"
|
|
|
|
print("✓ All toolkit methods are compatible with @tool decorator")
|
|
return True
|
|
|
|
|
|
if __name__ == "__main__":
|
|
print("Testing mock toolkit fixes for TradingAgentsGraph...")
|
|
print("-" * 50)
|
|
|
|
tests = [
|
|
test_mock_toolkit_has_all_methods,
|
|
test_tool_node_creation,
|
|
test_tool_decorator,
|
|
]
|
|
|
|
all_passed = True
|
|
for test in tests:
|
|
try:
|
|
if not test():
|
|
all_passed = False
|
|
print(f"✗ {test.__name__} failed")
|
|
except Exception as e:
|
|
all_passed = False
|
|
print(f"✗ {test.__name__} raised exception: {e}")
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
|
|
print("-" * 50)
|
|
if all_passed:
|
|
print("✅ All tests passed! TradingAgentsGraph mock fixes are working.")
|
|
else:
|
|
print("❌ Some tests failed. Check the output above.")
|