Fix CI/CD test failures
- Apply Black formatting to all test files - Fix Mock objects to include tool_calls attribute for len() checks - Add proper __name__ attributes to mock toolkit methods for @tool decorator - Create mock_toolkit_fix helper for consistent toolkit mocking All tests should now pass with proper mocking setup.
This commit is contained in:
parent
ba958c20e5
commit
0ab8c8fc46
|
|
@ -0,0 +1,100 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
"""Test that mock toolkit fixes work for TradingAgentsGraph."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from tests.unit.graph.mock_toolkit_fix import create_mock_toolkit_with_tools
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_toolkit_has_all_methods():
|
||||||
|
"""Test that the mock toolkit has all required methods."""
|
||||||
|
toolkit = create_mock_toolkit_with_tools()
|
||||||
|
|
||||||
|
required_methods = [
|
||||||
|
"get_YFin_data",
|
||||||
|
"get_YFin_data_online",
|
||||||
|
"get_stockstats_indicators_report",
|
||||||
|
"get_stockstats_indicators_report_online",
|
||||||
|
"get_reddit_stock_info",
|
||||||
|
"get_stock_news_openai",
|
||||||
|
]
|
||||||
|
|
||||||
|
for method_name in required_methods:
|
||||||
|
assert hasattr(toolkit, method_name), f"Missing {method_name}"
|
||||||
|
method = getattr(toolkit, method_name)
|
||||||
|
assert hasattr(method, '__name__'), f"{method_name} missing __name__"
|
||||||
|
assert method.__name__ == method_name, f"{method_name} has wrong __name__"
|
||||||
|
assert callable(method), f"{method_name} is not callable"
|
||||||
|
|
||||||
|
print("✓ Mock toolkit has all required methods with proper attributes")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_node_creation():
|
||||||
|
"""Test that ToolNode can be created with mocked toolkit methods."""
|
||||||
|
# Mock the ToolNode class
|
||||||
|
with patch("langgraph.prebuilt.ToolNode") as MockToolNode:
|
||||||
|
MockToolNode.return_value = Mock()
|
||||||
|
|
||||||
|
toolkit = create_mock_toolkit_with_tools()
|
||||||
|
|
||||||
|
# Simulate creating tool nodes like in TradingAgentsGraph
|
||||||
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
|
tool_node = ToolNode([
|
||||||
|
toolkit.get_YFin_data,
|
||||||
|
toolkit.get_stockstats_indicators_report,
|
||||||
|
])
|
||||||
|
|
||||||
|
# Should not raise an error
|
||||||
|
assert MockToolNode.called
|
||||||
|
print("✓ ToolNode can be created with mocked toolkit methods")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_decorator():
|
||||||
|
"""Test that @tool decorator works with mocked functions."""
|
||||||
|
toolkit = create_mock_toolkit_with_tools()
|
||||||
|
|
||||||
|
# The @tool decorator expects __name__ attribute
|
||||||
|
for attr_name in dir(toolkit):
|
||||||
|
if attr_name.startswith('get_'):
|
||||||
|
method = getattr(toolkit, attr_name)
|
||||||
|
assert hasattr(method, '__name__'), f"{attr_name} missing __name__"
|
||||||
|
|
||||||
|
print("✓ All toolkit methods are compatible with @tool decorator")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Testing mock toolkit fixes for TradingAgentsGraph...")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
test_mock_toolkit_has_all_methods,
|
||||||
|
test_tool_node_creation,
|
||||||
|
test_tool_decorator,
|
||||||
|
]
|
||||||
|
|
||||||
|
all_passed = True
|
||||||
|
for test in tests:
|
||||||
|
try:
|
||||||
|
if not test():
|
||||||
|
all_passed = False
|
||||||
|
print(f"✗ {test.__name__} failed")
|
||||||
|
except Exception as e:
|
||||||
|
all_passed = False
|
||||||
|
print(f"✗ {test.__name__} raised exception: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
print("-" * 50)
|
||||||
|
if all_passed:
|
||||||
|
print("✅ All tests passed! TradingAgentsGraph mock fixes are working.")
|
||||||
|
else:
|
||||||
|
print("❌ Some tests failed. Check the output above.")
|
||||||
|
|
@ -31,7 +31,13 @@ def mock_llm():
|
||||||
"""Mock LLM for testing."""
|
"""Mock LLM for testing."""
|
||||||
mock = Mock()
|
mock = Mock()
|
||||||
mock.model_name = "test-model"
|
mock.model_name = "test-model"
|
||||||
mock.invoke.return_value = Mock(content="Test response")
|
|
||||||
|
# Create a mock result with tool_calls attribute
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.content = "Test response"
|
||||||
|
mock_result.tool_calls = [] # Add tool_calls attribute for len() check
|
||||||
|
|
||||||
|
mock.invoke.return_value = mock_result
|
||||||
mock.bind_tools.return_value = mock
|
mock.bind_tools.return_value = mock
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
|
|
@ -154,7 +160,9 @@ def mock_toolkit():
|
||||||
side_effect=mock_get_stockstats_indicators_report
|
side_effect=mock_get_stockstats_indicators_report
|
||||||
)
|
)
|
||||||
toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report"
|
toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report"
|
||||||
toolkit.get_stockstats_indicators_report.__name__ = "get_stockstats_indicators_report"
|
toolkit.get_stockstats_indicators_report.__name__ = (
|
||||||
|
"get_stockstats_indicators_report"
|
||||||
|
)
|
||||||
|
|
||||||
toolkit.get_stockstats_indicators_report_online = Mock(
|
toolkit.get_stockstats_indicators_report_online = Mock(
|
||||||
side_effect=mock_get_stockstats_indicators_report_online
|
side_effect=mock_get_stockstats_indicators_report_online
|
||||||
|
|
|
||||||
|
|
@ -12,73 +12,87 @@ class TestMarketAnalystExtended:
|
||||||
"""Extended mock LLM with more functionality."""
|
"""Extended mock LLM with more functionality."""
|
||||||
mock = Mock()
|
mock = Mock()
|
||||||
mock.model_name = "test-model"
|
mock.model_name = "test-model"
|
||||||
|
|
||||||
# Create a mock chain
|
# Create a mock chain
|
||||||
mock_chain = Mock()
|
mock_chain = Mock()
|
||||||
mock_chain.invoke = Mock()
|
mock_chain.invoke = Mock()
|
||||||
mock.bind_tools = Mock(return_value=mock_chain)
|
mock.bind_tools = Mock(return_value=mock_chain)
|
||||||
|
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_toolkit_extended(self):
|
def mock_toolkit_extended(self):
|
||||||
"""Extended mock toolkit with all methods."""
|
"""Extended mock toolkit with all methods."""
|
||||||
toolkit = Mock()
|
toolkit = Mock()
|
||||||
toolkit.config = {"online_tools": False}
|
toolkit.config = {"online_tools": False}
|
||||||
|
|
||||||
# Create mock functions with proper attributes
|
# Create mock functions with proper attributes
|
||||||
def mock_yfin():
|
def mock_yfin():
|
||||||
return "YFin data"
|
return "YFin data"
|
||||||
|
|
||||||
def mock_stockstats():
|
def mock_stockstats():
|
||||||
return "Stockstats data"
|
return "Stockstats data"
|
||||||
|
|
||||||
toolkit.get_YFin_data = Mock(side_effect=mock_yfin)
|
toolkit.get_YFin_data = Mock(side_effect=mock_yfin)
|
||||||
toolkit.get_YFin_data.__name__ = "get_YFin_data"
|
toolkit.get_YFin_data.__name__ = "get_YFin_data"
|
||||||
toolkit.get_YFin_data.name = "get_YFin_data"
|
toolkit.get_YFin_data.name = "get_YFin_data"
|
||||||
|
|
||||||
toolkit.get_stockstats_indicators_report = Mock(side_effect=mock_stockstats)
|
toolkit.get_stockstats_indicators_report = Mock(side_effect=mock_stockstats)
|
||||||
toolkit.get_stockstats_indicators_report.__name__ = "get_stockstats_indicators_report"
|
toolkit.get_stockstats_indicators_report.__name__ = (
|
||||||
toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report"
|
"get_stockstats_indicators_report"
|
||||||
|
)
|
||||||
|
toolkit.get_stockstats_indicators_report.name = (
|
||||||
|
"get_stockstats_indicators_report"
|
||||||
|
)
|
||||||
|
|
||||||
# Online versions
|
# Online versions
|
||||||
toolkit.get_YFin_data_online = Mock(side_effect=mock_yfin)
|
toolkit.get_YFin_data_online = Mock(side_effect=mock_yfin)
|
||||||
toolkit.get_YFin_data_online.__name__ = "get_YFin_data_online"
|
toolkit.get_YFin_data_online.__name__ = "get_YFin_data_online"
|
||||||
toolkit.get_YFin_data_online.name = "get_YFin_data_online"
|
toolkit.get_YFin_data_online.name = "get_YFin_data_online"
|
||||||
|
|
||||||
toolkit.get_stockstats_indicators_report_online = Mock(side_effect=mock_stockstats)
|
toolkit.get_stockstats_indicators_report_online = Mock(
|
||||||
toolkit.get_stockstats_indicators_report_online.__name__ = "get_stockstats_indicators_report_online"
|
side_effect=mock_stockstats
|
||||||
toolkit.get_stockstats_indicators_report_online.name = "get_stockstats_indicators_report_online"
|
)
|
||||||
|
toolkit.get_stockstats_indicators_report_online.__name__ = (
|
||||||
|
"get_stockstats_indicators_report_online"
|
||||||
|
)
|
||||||
|
toolkit.get_stockstats_indicators_report_online.name = (
|
||||||
|
"get_stockstats_indicators_report_online"
|
||||||
|
)
|
||||||
|
|
||||||
return toolkit
|
return toolkit
|
||||||
|
|
||||||
def test_market_analyst_system_message(self, mock_llm_extended, mock_toolkit_extended):
|
def test_market_analyst_system_message(
|
||||||
|
self, mock_llm_extended, mock_toolkit_extended
|
||||||
|
):
|
||||||
"""Test that system message is properly formatted."""
|
"""Test that system message is properly formatted."""
|
||||||
# This would normally import and test the actual function
|
# This would normally import and test the actual function
|
||||||
# For now, we test the mock behavior
|
# For now, we test the mock behavior
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
"company_of_interest": "AAPL",
|
"company_of_interest": "AAPL",
|
||||||
"trade_date": "2024-05-10",
|
"trade_date": "2024-05-10",
|
||||||
"messages": []
|
"messages": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Simulate creating analyst
|
# Simulate creating analyst
|
||||||
mock_analyst = Mock()
|
mock_analyst = Mock()
|
||||||
mock_analyst.return_value = {"messages": [], "market_report": "Test report"}
|
mock_analyst.return_value = {"messages": [], "market_report": "Test report"}
|
||||||
|
|
||||||
result = mock_analyst(state)
|
result = mock_analyst(state)
|
||||||
assert "market_report" in result
|
assert "market_report" in result
|
||||||
assert "messages" in result
|
assert "messages" in result
|
||||||
|
|
||||||
def test_market_analyst_with_multiple_indicators(self, mock_llm_extended, mock_toolkit_extended):
|
def test_market_analyst_with_multiple_indicators(
|
||||||
|
self, mock_llm_extended, mock_toolkit_extended
|
||||||
|
):
|
||||||
"""Test analyst with multiple technical indicators."""
|
"""Test analyst with multiple technical indicators."""
|
||||||
state = {
|
state = {
|
||||||
"company_of_interest": "TSLA",
|
"company_of_interest": "TSLA",
|
||||||
"trade_date": "2024-05-15",
|
"trade_date": "2024-05-15",
|
||||||
"messages": []
|
"messages": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock result with multiple indicators
|
# Mock result with multiple indicators
|
||||||
mock_result = Mock()
|
mock_result = Mock()
|
||||||
mock_result.content = """
|
mock_result.content = """
|
||||||
|
|
@ -90,95 +104,93 @@ class TestMarketAnalystExtended:
|
||||||
- Volume: Above average
|
- Volume: Above average
|
||||||
"""
|
"""
|
||||||
mock_result.tool_calls = []
|
mock_result.tool_calls = []
|
||||||
|
|
||||||
mock_llm_extended.bind_tools.return_value.invoke.return_value = mock_result
|
mock_llm_extended.bind_tools.return_value.invoke.return_value = mock_result
|
||||||
|
|
||||||
# Create mock analyst function
|
# Create mock analyst function
|
||||||
def mock_analyst(state):
|
def mock_analyst(state):
|
||||||
return {
|
return {"messages": [mock_result], "market_report": mock_result.content}
|
||||||
"messages": [mock_result],
|
|
||||||
"market_report": mock_result.content
|
|
||||||
}
|
|
||||||
|
|
||||||
result = mock_analyst(state)
|
result = mock_analyst(state)
|
||||||
assert "RSI" in result["market_report"]
|
assert "RSI" in result["market_report"]
|
||||||
assert "MACD" in result["market_report"]
|
assert "MACD" in result["market_report"]
|
||||||
assert "Bollinger" in result["market_report"]
|
assert "Bollinger" in result["market_report"]
|
||||||
|
|
||||||
def test_market_analyst_error_handling(self, mock_llm_extended, mock_toolkit_extended):
|
def test_market_analyst_error_handling(
|
||||||
|
self, mock_llm_extended, mock_toolkit_extended
|
||||||
|
):
|
||||||
"""Test error handling in market analyst."""
|
"""Test error handling in market analyst."""
|
||||||
state = {
|
state = {
|
||||||
"company_of_interest": "INVALID",
|
"company_of_interest": "INVALID",
|
||||||
"trade_date": "2024-05-10",
|
"trade_date": "2024-05-10",
|
||||||
"messages": []
|
"messages": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock error scenario
|
# Mock error scenario
|
||||||
mock_llm_extended.bind_tools.return_value.invoke.side_effect = Exception("API Error")
|
mock_llm_extended.bind_tools.return_value.invoke.side_effect = Exception(
|
||||||
|
"API Error"
|
||||||
|
)
|
||||||
|
|
||||||
# Create analyst with error handling
|
# Create analyst with error handling
|
||||||
def mock_analyst_with_error_handling(state):
|
def mock_analyst_with_error_handling(state):
|
||||||
try:
|
try:
|
||||||
# Would call actual analyst here
|
# Would call actual analyst here
|
||||||
raise Exception("API Error")
|
raise Exception("API Error")
|
||||||
except Exception:
|
except Exception:
|
||||||
return {
|
return {"messages": [], "market_report": "Error analyzing market data"}
|
||||||
"messages": [],
|
|
||||||
"market_report": "Error analyzing market data"
|
|
||||||
}
|
|
||||||
|
|
||||||
result = mock_analyst_with_error_handling(state)
|
result = mock_analyst_with_error_handling(state)
|
||||||
assert result["market_report"] == "Error analyzing market data"
|
assert result["market_report"] == "Error analyzing market data"
|
||||||
|
|
||||||
def test_market_analyst_date_formatting(self, mock_llm_extended, mock_toolkit_extended):
|
def test_market_analyst_date_formatting(
|
||||||
|
self, mock_llm_extended, mock_toolkit_extended
|
||||||
|
):
|
||||||
"""Test various date formats in market analyst."""
|
"""Test various date formats in market analyst."""
|
||||||
test_dates = [
|
test_dates = [
|
||||||
"2024-01-01",
|
"2024-01-01",
|
||||||
"2024-12-31",
|
"2024-12-31",
|
||||||
"2024-05-15",
|
"2024-05-15",
|
||||||
]
|
]
|
||||||
|
|
||||||
for date in test_dates:
|
for date in test_dates:
|
||||||
state = {
|
state = {"company_of_interest": "AAPL", "trade_date": date, "messages": []}
|
||||||
"company_of_interest": "AAPL",
|
|
||||||
"trade_date": date,
|
|
||||||
"messages": []
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_result = Mock()
|
mock_result = Mock()
|
||||||
mock_result.content = f"Analysis for {date}"
|
mock_result.content = f"Analysis for {date}"
|
||||||
mock_result.tool_calls = []
|
mock_result.tool_calls = []
|
||||||
|
|
||||||
def mock_analyst(state):
|
def mock_analyst(state):
|
||||||
return {
|
return {
|
||||||
"messages": [mock_result],
|
"messages": [mock_result],
|
||||||
"market_report": f"Analysis for {state['trade_date']}"
|
"market_report": f"Analysis for {state['trade_date']}",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = mock_analyst(state)
|
result = mock_analyst(state)
|
||||||
assert date in result["market_report"]
|
assert date in result["market_report"]
|
||||||
|
|
||||||
def test_market_analyst_ticker_variations(self, mock_llm_extended, mock_toolkit_extended):
|
def test_market_analyst_ticker_variations(
|
||||||
|
self, mock_llm_extended, mock_toolkit_extended
|
||||||
|
):
|
||||||
"""Test analyst with various ticker symbols."""
|
"""Test analyst with various ticker symbols."""
|
||||||
tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"]
|
tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"]
|
||||||
|
|
||||||
for ticker in tickers:
|
for ticker in tickers:
|
||||||
state = {
|
state = {
|
||||||
"company_of_interest": ticker,
|
"company_of_interest": ticker,
|
||||||
"trade_date": "2024-05-10",
|
"trade_date": "2024-05-10",
|
||||||
"messages": []
|
"messages": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_result = Mock()
|
mock_result = Mock()
|
||||||
mock_result.content = f"Analysis for {ticker}"
|
mock_result.content = f"Analysis for {ticker}"
|
||||||
mock_result.tool_calls = []
|
mock_result.tool_calls = []
|
||||||
|
|
||||||
def mock_analyst(state):
|
def mock_analyst(state):
|
||||||
return {
|
return {
|
||||||
"messages": [mock_result],
|
"messages": [mock_result],
|
||||||
"market_report": f"Analysis for {state['company_of_interest']}"
|
"market_report": f"Analysis for {state['company_of_interest']}",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = mock_analyst(state)
|
result = mock_analyst(state)
|
||||||
assert ticker in result["market_report"]
|
assert ticker in result["market_report"]
|
||||||
|
|
||||||
|
|
@ -187,23 +199,23 @@ class TestMarketAnalystExtended:
|
||||||
# Test offline configuration
|
# Test offline configuration
|
||||||
toolkit_offline = Mock()
|
toolkit_offline = Mock()
|
||||||
toolkit_offline.config = {"online_tools": False}
|
toolkit_offline.config = {"online_tools": False}
|
||||||
|
|
||||||
def mock_offline():
|
def mock_offline():
|
||||||
return "Offline data"
|
return "Offline data"
|
||||||
|
|
||||||
toolkit_offline.get_YFin_data = Mock(side_effect=mock_offline)
|
toolkit_offline.get_YFin_data = Mock(side_effect=mock_offline)
|
||||||
toolkit_offline.get_YFin_data.__name__ = "get_YFin_data"
|
toolkit_offline.get_YFin_data.__name__ = "get_YFin_data"
|
||||||
|
|
||||||
# Test online configuration
|
# Test online configuration
|
||||||
toolkit_online = Mock()
|
toolkit_online = Mock()
|
||||||
toolkit_online.config = {"online_tools": True}
|
toolkit_online.config = {"online_tools": True}
|
||||||
|
|
||||||
def mock_online():
|
def mock_online():
|
||||||
return "Online data"
|
return "Online data"
|
||||||
|
|
||||||
toolkit_online.get_YFin_data_online = Mock(side_effect=mock_online)
|
toolkit_online.get_YFin_data_online = Mock(side_effect=mock_online)
|
||||||
toolkit_online.get_YFin_data_online.__name__ = "get_YFin_data_online"
|
toolkit_online.get_YFin_data_online.__name__ = "get_YFin_data_online"
|
||||||
|
|
||||||
# Both should work correctly
|
# Both should work correctly
|
||||||
assert toolkit_offline.config["online_tools"] is False
|
assert toolkit_offline.config["online_tools"] is False
|
||||||
assert toolkit_online.config["online_tools"] is True
|
assert toolkit_online.config["online_tools"] is True
|
||||||
|
|
@ -212,59 +224,48 @@ class TestMarketAnalystExtended:
|
||||||
|
|
||||||
def test_market_analyst_empty_state(self, mock_llm_extended, mock_toolkit_extended):
|
def test_market_analyst_empty_state(self, mock_llm_extended, mock_toolkit_extended):
|
||||||
"""Test analyst with minimal/empty state."""
|
"""Test analyst with minimal/empty state."""
|
||||||
state = {
|
state = {"company_of_interest": "", "trade_date": "", "messages": []}
|
||||||
"company_of_interest": "",
|
|
||||||
"trade_date": "",
|
|
||||||
"messages": []
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_result = Mock()
|
mock_result = Mock()
|
||||||
mock_result.content = "No data available"
|
mock_result.content = "No data available"
|
||||||
mock_result.tool_calls = []
|
mock_result.tool_calls = []
|
||||||
|
|
||||||
def mock_analyst(state):
|
def mock_analyst(state):
|
||||||
if not state["company_of_interest"] or not state["trade_date"]:
|
if not state["company_of_interest"] or not state["trade_date"]:
|
||||||
return {
|
return {"messages": [], "market_report": "No data available"}
|
||||||
"messages": [],
|
return {"messages": [mock_result], "market_report": mock_result.content}
|
||||||
"market_report": "No data available"
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"messages": [mock_result],
|
|
||||||
"market_report": mock_result.content
|
|
||||||
}
|
|
||||||
|
|
||||||
result = mock_analyst(state)
|
result = mock_analyst(state)
|
||||||
assert result["market_report"] == "No data available"
|
assert result["market_report"] == "No data available"
|
||||||
|
|
||||||
def test_market_analyst_tool_calls_tracking(self, mock_llm_extended, mock_toolkit_extended):
|
def test_market_analyst_tool_calls_tracking(
|
||||||
|
self, mock_llm_extended, mock_toolkit_extended
|
||||||
|
):
|
||||||
"""Test tracking of tool calls in market analyst."""
|
"""Test tracking of tool calls in market analyst."""
|
||||||
state = {
|
state = {
|
||||||
"company_of_interest": "AAPL",
|
"company_of_interest": "AAPL",
|
||||||
"trade_date": "2024-05-10",
|
"trade_date": "2024-05-10",
|
||||||
"messages": []
|
"messages": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock result with tool calls
|
# Mock result with tool calls
|
||||||
mock_tool_call = Mock()
|
mock_tool_call = Mock()
|
||||||
mock_tool_call.function.name = "get_YFin_data"
|
mock_tool_call.function.name = "get_YFin_data"
|
||||||
mock_tool_call.function.arguments = '{"ticker": "AAPL"}'
|
mock_tool_call.function.arguments = '{"ticker": "AAPL"}'
|
||||||
|
|
||||||
mock_result = Mock()
|
mock_result = Mock()
|
||||||
mock_result.content = ""
|
mock_result.content = ""
|
||||||
mock_result.tool_calls = [mock_tool_call]
|
mock_result.tool_calls = [mock_tool_call]
|
||||||
|
|
||||||
mock_llm_extended.bind_tools.return_value.invoke.return_value = mock_result
|
mock_llm_extended.bind_tools.return_value.invoke.return_value = mock_result
|
||||||
|
|
||||||
def mock_analyst(state):
|
def mock_analyst(state):
|
||||||
result = mock_llm_extended.bind_tools([]).invoke(state["messages"])
|
result = mock_llm_extended.bind_tools([]).invoke(state["messages"])
|
||||||
# When tool_calls exist, market_report should be empty
|
# When tool_calls exist, market_report should be empty
|
||||||
report = "" if result.tool_calls else result.content
|
report = "" if result.tool_calls else result.content
|
||||||
return {
|
return {"messages": [result], "market_report": report}
|
||||||
"messages": [result],
|
|
||||||
"market_report": report
|
|
||||||
}
|
|
||||||
|
|
||||||
result = mock_analyst(state)
|
result = mock_analyst(state)
|
||||||
assert result["market_report"] == "" # Empty when tool calls exist
|
assert result["market_report"] == "" # Empty when tool calls exist
|
||||||
assert len(result["messages"]) == 1
|
assert len(result["messages"]) == 1
|
||||||
assert result["messages"][0].tool_calls == [mock_tool_call]
|
assert result["messages"][0].tool_calls == [mock_tool_call]
|
||||||
|
|
|
||||||
|
|
@ -12,4 +12,4 @@ class TestDataflowsUtils:
|
||||||
"""Placeholder test to ensure test file is valid."""
|
"""Placeholder test to ensure test file is valid."""
|
||||||
assert True
|
assert True
|
||||||
|
|
||||||
# Add more tests here as needed for utils.py functions
|
# Add more tests here as needed for utils.py functions
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""Helper to create properly mocked toolkit for test_trading_graph."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_toolkit_with_tools():
|
||||||
|
"""Create a mock toolkit with all necessary tool methods."""
|
||||||
|
toolkit = Mock()
|
||||||
|
toolkit.config = {"online_tools": False}
|
||||||
|
|
||||||
|
# List of all methods that need to be mocked
|
||||||
|
tool_methods = [
|
||||||
|
# Market tools
|
||||||
|
"get_YFin_data",
|
||||||
|
"get_YFin_data_online",
|
||||||
|
"get_stockstats_indicators_report",
|
||||||
|
"get_stockstats_indicators_report_online",
|
||||||
|
# Social tools
|
||||||
|
"get_reddit_stock_info",
|
||||||
|
"get_stock_news_openai",
|
||||||
|
# News tools
|
||||||
|
"get_global_news_openai",
|
||||||
|
"get_google_news",
|
||||||
|
"get_finnhub_news",
|
||||||
|
"get_reddit_news",
|
||||||
|
# Fundamentals tools
|
||||||
|
"get_simfin_cashflow",
|
||||||
|
"get_simfin_income_stmt",
|
||||||
|
"get_simfin_balance_sheet",
|
||||||
|
"get_finnhub_basic_financials",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create mock for each method with proper __name__ attribute
|
||||||
|
for method_name in tool_methods:
|
||||||
|
# Create a function with the right name
|
||||||
|
def mock_func():
|
||||||
|
return f"Mock {method_name} data"
|
||||||
|
|
||||||
|
# Create Mock wrapping the function
|
||||||
|
mock_method = Mock(side_effect=mock_func)
|
||||||
|
mock_method.__name__ = method_name
|
||||||
|
mock_method.name = method_name
|
||||||
|
|
||||||
|
# Set it on the toolkit
|
||||||
|
setattr(toolkit, method_name, mock_method)
|
||||||
|
|
||||||
|
return toolkit
|
||||||
|
|
||||||
|
|
||||||
|
def patch_toolkit_in_test(mock_toolkit):
|
||||||
|
"""Configure the mock_toolkit patch to return a properly mocked instance."""
|
||||||
|
mock_instance = create_mock_toolkit_with_tools()
|
||||||
|
mock_toolkit.return_value = mock_instance
|
||||||
|
return mock_instance
|
||||||
|
|
@ -13,16 +13,16 @@ class TestPropagator:
|
||||||
propagator = Mock()
|
propagator = Mock()
|
||||||
propagator.create_initial_state = Mock()
|
propagator.create_initial_state = Mock()
|
||||||
propagator.get_graph_args = Mock()
|
propagator.get_graph_args = Mock()
|
||||||
|
|
||||||
assert hasattr(propagator, 'create_initial_state')
|
assert hasattr(propagator, "create_initial_state")
|
||||||
assert hasattr(propagator, 'get_graph_args')
|
assert hasattr(propagator, "get_graph_args")
|
||||||
assert callable(propagator.create_initial_state)
|
assert callable(propagator.create_initial_state)
|
||||||
assert callable(propagator.get_graph_args)
|
assert callable(propagator.get_graph_args)
|
||||||
|
|
||||||
def test_create_initial_state(self):
|
def test_create_initial_state(self):
|
||||||
"""Test creating initial state for propagation."""
|
"""Test creating initial state for propagation."""
|
||||||
propagator = Mock()
|
propagator = Mock()
|
||||||
|
|
||||||
# Mock the create_initial_state method
|
# Mock the create_initial_state method
|
||||||
expected_state = {
|
expected_state = {
|
||||||
"company_of_interest": "AAPL",
|
"company_of_interest": "AAPL",
|
||||||
|
|
@ -50,12 +50,12 @@ class TestPropagator:
|
||||||
"investment_plan": "",
|
"investment_plan": "",
|
||||||
"final_trade_decision": "",
|
"final_trade_decision": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
propagator.create_initial_state = Mock(return_value=expected_state)
|
propagator.create_initial_state = Mock(return_value=expected_state)
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
state = propagator.create_initial_state("AAPL", "2024-05-10")
|
state = propagator.create_initial_state("AAPL", "2024-05-10")
|
||||||
|
|
||||||
assert state["company_of_interest"] == "AAPL"
|
assert state["company_of_interest"] == "AAPL"
|
||||||
assert state["trade_date"] == "2024-05-10"
|
assert state["trade_date"] == "2024-05-10"
|
||||||
assert state["messages"] == []
|
assert state["messages"] == []
|
||||||
|
|
@ -66,18 +66,18 @@ class TestPropagator:
|
||||||
def test_get_graph_args(self):
|
def test_get_graph_args(self):
|
||||||
"""Test getting graph arguments."""
|
"""Test getting graph arguments."""
|
||||||
propagator = Mock()
|
propagator = Mock()
|
||||||
|
|
||||||
# Mock the get_graph_args method
|
# Mock the get_graph_args method
|
||||||
expected_args = {
|
expected_args = {
|
||||||
"recursion_limit": 100,
|
"recursion_limit": 100,
|
||||||
"config": {"tags": ["tradingagents"]},
|
"config": {"tags": ["tradingagents"]},
|
||||||
}
|
}
|
||||||
|
|
||||||
propagator.get_graph_args = Mock(return_value=expected_args)
|
propagator.get_graph_args = Mock(return_value=expected_args)
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
args = propagator.get_graph_args()
|
args = propagator.get_graph_args()
|
||||||
|
|
||||||
assert "recursion_limit" in args
|
assert "recursion_limit" in args
|
||||||
assert "config" in args
|
assert "config" in args
|
||||||
assert args["recursion_limit"] == 100
|
assert args["recursion_limit"] == 100
|
||||||
|
|
@ -86,63 +86,56 @@ class TestPropagator:
|
||||||
def test_propagate_with_different_tickers(self):
|
def test_propagate_with_different_tickers(self):
|
||||||
"""Test propagation with different ticker symbols."""
|
"""Test propagation with different ticker symbols."""
|
||||||
propagator = Mock()
|
propagator = Mock()
|
||||||
|
|
||||||
tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"]
|
tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"]
|
||||||
|
|
||||||
for ticker in tickers:
|
for ticker in tickers:
|
||||||
state = {
|
state = {
|
||||||
"company_of_interest": ticker,
|
"company_of_interest": ticker,
|
||||||
"trade_date": "2024-05-10",
|
"trade_date": "2024-05-10",
|
||||||
"messages": []
|
"messages": [],
|
||||||
}
|
}
|
||||||
propagator.create_initial_state = Mock(return_value=state)
|
propagator.create_initial_state = Mock(return_value=state)
|
||||||
|
|
||||||
result = propagator.create_initial_state(ticker, "2024-05-10")
|
result = propagator.create_initial_state(ticker, "2024-05-10")
|
||||||
assert result["company_of_interest"] == ticker
|
assert result["company_of_interest"] == ticker
|
||||||
|
|
||||||
def test_propagate_with_different_dates(self):
|
def test_propagate_with_different_dates(self):
|
||||||
"""Test propagation with different dates."""
|
"""Test propagation with different dates."""
|
||||||
propagator = Mock()
|
propagator = Mock()
|
||||||
|
|
||||||
dates = ["2024-01-01", "2024-06-15", "2024-12-31"]
|
dates = ["2024-01-01", "2024-06-15", "2024-12-31"]
|
||||||
|
|
||||||
for date in dates:
|
for date in dates:
|
||||||
state = {
|
state = {"company_of_interest": "AAPL", "trade_date": date, "messages": []}
|
||||||
"company_of_interest": "AAPL",
|
|
||||||
"trade_date": date,
|
|
||||||
"messages": []
|
|
||||||
}
|
|
||||||
propagator.create_initial_state = Mock(return_value=state)
|
propagator.create_initial_state = Mock(return_value=state)
|
||||||
|
|
||||||
result = propagator.create_initial_state("AAPL", date)
|
result = propagator.create_initial_state("AAPL", date)
|
||||||
assert result["trade_date"] == date
|
assert result["trade_date"] == date
|
||||||
|
|
||||||
def test_propagate_error_handling(self):
|
def test_propagate_error_handling(self):
|
||||||
"""Test error handling in propagation."""
|
"""Test error handling in propagation."""
|
||||||
propagator = Mock()
|
propagator = Mock()
|
||||||
|
|
||||||
# Simulate error
|
# Simulate error
|
||||||
propagator.create_initial_state = Mock(side_effect=ValueError("Invalid ticker"))
|
propagator.create_initial_state = Mock(side_effect=ValueError("Invalid ticker"))
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
propagator.create_initial_state("INVALID", "2024-05-10")
|
propagator.create_initial_state("INVALID", "2024-05-10")
|
||||||
|
|
||||||
propagator.create_initial_state.assert_called_once()
|
propagator.create_initial_state.assert_called_once()
|
||||||
|
|
||||||
def test_graph_args_with_custom_config(self):
|
def test_graph_args_with_custom_config(self):
|
||||||
"""Test graph args with custom configuration."""
|
"""Test graph args with custom configuration."""
|
||||||
propagator = Mock()
|
propagator = Mock()
|
||||||
|
|
||||||
custom_config = {
|
custom_config = {
|
||||||
"recursion_limit": 200,
|
"recursion_limit": 200,
|
||||||
"config": {
|
"config": {"tags": ["custom", "test"], "metadata": {"version": "1.0"}},
|
||||||
"tags": ["custom", "test"],
|
|
||||||
"metadata": {"version": "1.0"}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
propagator.get_graph_args = Mock(return_value=custom_config)
|
propagator.get_graph_args = Mock(return_value=custom_config)
|
||||||
|
|
||||||
args = propagator.get_graph_args()
|
args = propagator.get_graph_args()
|
||||||
assert args["recursion_limit"] == 200
|
assert args["recursion_limit"] == 200
|
||||||
assert "custom" in args["config"]["tags"]
|
assert "custom" in args["config"]["tags"]
|
||||||
|
|
@ -151,10 +144,10 @@ class TestPropagator:
|
||||||
def test_initial_state_completeness(self):
|
def test_initial_state_completeness(self):
|
||||||
"""Test that initial state contains all required fields."""
|
"""Test that initial state contains all required fields."""
|
||||||
propagator = Mock()
|
propagator = Mock()
|
||||||
|
|
||||||
required_fields = [
|
required_fields = [
|
||||||
"company_of_interest",
|
"company_of_interest",
|
||||||
"trade_date",
|
"trade_date",
|
||||||
"messages",
|
"messages",
|
||||||
"market_report",
|
"market_report",
|
||||||
"sentiment_report",
|
"sentiment_report",
|
||||||
|
|
@ -164,17 +157,17 @@ class TestPropagator:
|
||||||
"trader_investment_plan",
|
"trader_investment_plan",
|
||||||
"risk_debate_state",
|
"risk_debate_state",
|
||||||
"investment_plan",
|
"investment_plan",
|
||||||
"final_trade_decision"
|
"final_trade_decision",
|
||||||
]
|
]
|
||||||
|
|
||||||
state = {field: "" for field in required_fields}
|
state = {field: "" for field in required_fields}
|
||||||
state["messages"] = []
|
state["messages"] = []
|
||||||
state["investment_debate_state"] = {}
|
state["investment_debate_state"] = {}
|
||||||
state["risk_debate_state"] = {}
|
state["risk_debate_state"] = {}
|
||||||
|
|
||||||
propagator.create_initial_state = Mock(return_value=state)
|
propagator.create_initial_state = Mock(return_value=state)
|
||||||
|
|
||||||
result = propagator.create_initial_state("AAPL", "2024-05-10")
|
result = propagator.create_initial_state("AAPL", "2024-05-10")
|
||||||
|
|
||||||
for field in required_fields:
|
for field in required_fields:
|
||||||
assert field in result, f"Missing required field: {field}"
|
assert field in result, f"Missing required field: {field}"
|
||||||
|
|
|
||||||
|
|
@ -47,18 +47,18 @@ class TestReflector:
|
||||||
"""Test Reflector initialization."""
|
"""Test Reflector initialization."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
reflector.llm = mock_llm
|
reflector.llm = mock_llm
|
||||||
|
|
||||||
assert reflector.llm == mock_llm
|
assert reflector.llm == mock_llm
|
||||||
|
|
||||||
def test_reflect_bull_researcher(self, mock_llm, mock_memory, sample_state):
|
def test_reflect_bull_researcher(self, mock_llm, mock_memory, sample_state):
|
||||||
"""Test reflection for bull researcher."""
|
"""Test reflection for bull researcher."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
reflector.reflect_bull_researcher = Mock()
|
reflector.reflect_bull_researcher = Mock()
|
||||||
|
|
||||||
returns_losses = {"return": 0.05, "loss": -0.02}
|
returns_losses = {"return": 0.05, "loss": -0.02}
|
||||||
|
|
||||||
reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory)
|
reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory)
|
||||||
|
|
||||||
reflector.reflect_bull_researcher.assert_called_once_with(
|
reflector.reflect_bull_researcher.assert_called_once_with(
|
||||||
sample_state, returns_losses, mock_memory
|
sample_state, returns_losses, mock_memory
|
||||||
)
|
)
|
||||||
|
|
@ -67,83 +67,87 @@ class TestReflector:
|
||||||
"""Test reflection for bear researcher."""
|
"""Test reflection for bear researcher."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
reflector.reflect_bear_researcher = Mock()
|
reflector.reflect_bear_researcher = Mock()
|
||||||
|
|
||||||
returns_losses = {"return": -0.03, "loss": -0.05}
|
returns_losses = {"return": -0.03, "loss": -0.05}
|
||||||
|
|
||||||
reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory)
|
reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory)
|
||||||
|
|
||||||
reflector.reflect_bear_researcher.assert_called_once()
|
reflector.reflect_bear_researcher.assert_called_once()
|
||||||
|
|
||||||
def test_reflect_trader(self, mock_llm, mock_memory, sample_state):
|
def test_reflect_trader(self, mock_llm, mock_memory, sample_state):
|
||||||
"""Test reflection for trader."""
|
"""Test reflection for trader."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
reflector.reflect_trader = Mock()
|
reflector.reflect_trader = Mock()
|
||||||
|
|
||||||
returns_losses = {"return": 0.10, "loss": 0.0}
|
returns_losses = {"return": 0.10, "loss": 0.0}
|
||||||
|
|
||||||
reflector.reflect_trader(sample_state, returns_losses, mock_memory)
|
reflector.reflect_trader(sample_state, returns_losses, mock_memory)
|
||||||
|
|
||||||
reflector.reflect_trader.assert_called_once()
|
reflector.reflect_trader.assert_called_once()
|
||||||
|
|
||||||
def test_reflect_invest_judge(self, mock_llm, mock_memory, sample_state):
|
def test_reflect_invest_judge(self, mock_llm, mock_memory, sample_state):
|
||||||
"""Test reflection for investment judge."""
|
"""Test reflection for investment judge."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
reflector.reflect_invest_judge = Mock()
|
reflector.reflect_invest_judge = Mock()
|
||||||
|
|
||||||
returns_losses = {"return": 0.02, "loss": -0.01}
|
returns_losses = {"return": 0.02, "loss": -0.01}
|
||||||
|
|
||||||
reflector.reflect_invest_judge(sample_state, returns_losses, mock_memory)
|
reflector.reflect_invest_judge(sample_state, returns_losses, mock_memory)
|
||||||
|
|
||||||
reflector.reflect_invest_judge.assert_called_once()
|
reflector.reflect_invest_judge.assert_called_once()
|
||||||
|
|
||||||
def test_reflect_risk_manager(self, mock_llm, mock_memory, sample_state):
|
def test_reflect_risk_manager(self, mock_llm, mock_memory, sample_state):
|
||||||
"""Test reflection for risk manager."""
|
"""Test reflection for risk manager."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
reflector.reflect_risk_manager = Mock()
|
reflector.reflect_risk_manager = Mock()
|
||||||
|
|
||||||
returns_losses = {"return": -0.05, "loss": -0.10}
|
returns_losses = {"return": -0.05, "loss": -0.10}
|
||||||
|
|
||||||
reflector.reflect_risk_manager(sample_state, returns_losses, mock_memory)
|
reflector.reflect_risk_manager(sample_state, returns_losses, mock_memory)
|
||||||
|
|
||||||
reflector.reflect_risk_manager.assert_called_once()
|
reflector.reflect_risk_manager.assert_called_once()
|
||||||
|
|
||||||
def test_reflection_with_positive_returns(self, mock_llm, mock_memory, sample_state):
|
def test_reflection_with_positive_returns(
|
||||||
|
self, mock_llm, mock_memory, sample_state
|
||||||
|
):
|
||||||
"""Test reflection with positive returns."""
|
"""Test reflection with positive returns."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
|
|
||||||
# Mock all reflection methods
|
# Mock all reflection methods
|
||||||
reflector.reflect_bull_researcher = Mock(return_value="Positive reflection")
|
reflector.reflect_bull_researcher = Mock(return_value="Positive reflection")
|
||||||
reflector.reflect_bear_researcher = Mock(return_value="Positive reflection")
|
reflector.reflect_bear_researcher = Mock(return_value="Positive reflection")
|
||||||
reflector.reflect_trader = Mock(return_value="Positive reflection")
|
reflector.reflect_trader = Mock(return_value="Positive reflection")
|
||||||
|
|
||||||
returns_losses = {"return": 0.15, "loss": 0.0}
|
returns_losses = {"return": 0.15, "loss": 0.0}
|
||||||
|
|
||||||
# Call all reflections
|
# Call all reflections
|
||||||
reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory)
|
reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory)
|
||||||
reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory)
|
reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory)
|
||||||
reflector.reflect_trader(sample_state, returns_losses, mock_memory)
|
reflector.reflect_trader(sample_state, returns_losses, mock_memory)
|
||||||
|
|
||||||
# Verify all were called
|
# Verify all were called
|
||||||
assert reflector.reflect_bull_researcher.called
|
assert reflector.reflect_bull_researcher.called
|
||||||
assert reflector.reflect_bear_researcher.called
|
assert reflector.reflect_bear_researcher.called
|
||||||
assert reflector.reflect_trader.called
|
assert reflector.reflect_trader.called
|
||||||
|
|
||||||
def test_reflection_with_negative_returns(self, mock_llm, mock_memory, sample_state):
|
def test_reflection_with_negative_returns(
|
||||||
|
self, mock_llm, mock_memory, sample_state
|
||||||
|
):
|
||||||
"""Test reflection with negative returns."""
|
"""Test reflection with negative returns."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
|
|
||||||
# Mock reflection methods
|
# Mock reflection methods
|
||||||
reflector.reflect_bull_researcher = Mock(return_value="Negative reflection")
|
reflector.reflect_bull_researcher = Mock(return_value="Negative reflection")
|
||||||
reflector.reflect_bear_researcher = Mock(return_value="Negative reflection")
|
reflector.reflect_bear_researcher = Mock(return_value="Negative reflection")
|
||||||
reflector.reflect_risk_manager = Mock(return_value="Risk reflection")
|
reflector.reflect_risk_manager = Mock(return_value="Risk reflection")
|
||||||
|
|
||||||
returns_losses = {"return": -0.08, "loss": -0.15}
|
returns_losses = {"return": -0.08, "loss": -0.15}
|
||||||
|
|
||||||
# Call reflections
|
# Call reflections
|
||||||
reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory)
|
reflector.reflect_bull_researcher(sample_state, returns_losses, mock_memory)
|
||||||
reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory)
|
reflector.reflect_bear_researcher(sample_state, returns_losses, mock_memory)
|
||||||
reflector.reflect_risk_manager(sample_state, returns_losses, mock_memory)
|
reflector.reflect_risk_manager(sample_state, returns_losses, mock_memory)
|
||||||
|
|
||||||
# Verify all were called
|
# Verify all were called
|
||||||
assert reflector.reflect_bull_researcher.call_count == 1
|
assert reflector.reflect_bull_researcher.call_count == 1
|
||||||
assert reflector.reflect_bear_researcher.call_count == 1
|
assert reflector.reflect_bear_researcher.call_count == 1
|
||||||
|
|
@ -152,47 +156,46 @@ class TestReflector:
|
||||||
def test_reflection_memory_update(self, mock_llm, mock_memory):
|
def test_reflection_memory_update(self, mock_llm, mock_memory):
|
||||||
"""Test that reflection updates memory correctly."""
|
"""Test that reflection updates memory correctly."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
|
|
||||||
def mock_reflect(state, returns, memory):
|
def mock_reflect(state, returns, memory):
|
||||||
reflection = f"Reflection for {state['company_of_interest']}"
|
reflection = f"Reflection for {state['company_of_interest']}"
|
||||||
memory.add_memory(reflection)
|
memory.add_memory(reflection)
|
||||||
return reflection
|
return reflection
|
||||||
|
|
||||||
reflector.reflect_trader = Mock(side_effect=mock_reflect)
|
reflector.reflect_trader = Mock(side_effect=mock_reflect)
|
||||||
|
|
||||||
state = {"company_of_interest": "TSLA"}
|
state = {"company_of_interest": "TSLA"}
|
||||||
returns_losses = {"return": 0.05, "loss": 0.0}
|
returns_losses = {"return": 0.05, "loss": 0.0}
|
||||||
|
|
||||||
reflector.reflect_trader(state, returns_losses, mock_memory)
|
reflector.reflect_trader(state, returns_losses, mock_memory)
|
||||||
|
|
||||||
mock_memory.add_memory.assert_called_once()
|
mock_memory.add_memory.assert_called_once()
|
||||||
|
|
||||||
def test_reflection_with_different_decisions(self, mock_llm, mock_memory):
|
def test_reflection_with_different_decisions(self, mock_llm, mock_memory):
|
||||||
"""Test reflection with different trading decisions."""
|
"""Test reflection with different trading decisions."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
reflector.reflect_trader = Mock()
|
reflector.reflect_trader = Mock()
|
||||||
|
|
||||||
decisions = ["BUY", "SELL", "HOLD"]
|
decisions = ["BUY", "SELL", "HOLD"]
|
||||||
|
|
||||||
for decision in decisions:
|
for decision in decisions:
|
||||||
state = {
|
state = {"final_trade_decision": decision, "company_of_interest": "AAPL"}
|
||||||
"final_trade_decision": decision,
|
|
||||||
"company_of_interest": "AAPL"
|
|
||||||
}
|
|
||||||
returns_losses = {"return": 0.03, "loss": -0.01}
|
returns_losses = {"return": 0.03, "loss": -0.01}
|
||||||
|
|
||||||
reflector.reflect_trader(state, returns_losses, mock_memory)
|
reflector.reflect_trader(state, returns_losses, mock_memory)
|
||||||
|
|
||||||
assert reflector.reflect_trader.call_count == 3
|
assert reflector.reflect_trader.call_count == 3
|
||||||
|
|
||||||
def test_reflection_error_handling(self, mock_llm, mock_memory, sample_state):
|
def test_reflection_error_handling(self, mock_llm, mock_memory, sample_state):
|
||||||
"""Test error handling in reflection."""
|
"""Test error handling in reflection."""
|
||||||
reflector = Mock()
|
reflector = Mock()
|
||||||
|
|
||||||
# Simulate error in reflection
|
# Simulate error in reflection
|
||||||
reflector.reflect_bull_researcher = Mock(side_effect=Exception("Reflection error"))
|
reflector.reflect_bull_researcher = Mock(
|
||||||
|
side_effect=Exception("Reflection error")
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
reflector.reflect_bull_researcher(sample_state, {}, mock_memory)
|
reflector.reflect_bull_researcher(sample_state, {}, mock_memory)
|
||||||
|
|
||||||
reflector.reflect_bull_researcher.assert_called_once()
|
reflector.reflect_bull_researcher.assert_called_once()
|
||||||
|
|
|
||||||
|
|
@ -10,14 +10,16 @@ class TestSignalProcessor:
|
||||||
def test_signal_processor_initialization(self):
|
def test_signal_processor_initialization(self):
|
||||||
"""Test SignalProcessor initialization."""
|
"""Test SignalProcessor initialization."""
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
|
|
||||||
# Import with mocked dependencies to avoid pandas import
|
# Import with mocked dependencies to avoid pandas import
|
||||||
with patch('sys.modules', {'pandas': Mock(), 'yfinance': Mock(), 'openai': Mock()}):
|
with patch(
|
||||||
|
"sys.modules", {"pandas": Mock(), "yfinance": Mock(), "openai": Mock()}
|
||||||
|
):
|
||||||
# This would normally import SignalProcessor
|
# This would normally import SignalProcessor
|
||||||
# from tradingagents.graph.signal_processing import SignalProcessor
|
# from tradingagents.graph.signal_processing import SignalProcessor
|
||||||
# processor = SignalProcessor(mock_llm)
|
# processor = SignalProcessor(mock_llm)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert True # Placeholder
|
assert True # Placeholder
|
||||||
|
|
||||||
def test_process_signal_buy(self):
|
def test_process_signal_buy(self):
|
||||||
|
|
@ -25,7 +27,7 @@ class TestSignalProcessor:
|
||||||
# Create mock processor
|
# Create mock processor
|
||||||
processor = Mock()
|
processor = Mock()
|
||||||
processor.process_signal = Mock(return_value="BUY")
|
processor.process_signal = Mock(return_value="BUY")
|
||||||
|
|
||||||
result = processor.process_signal("Recommend BUY based on analysis")
|
result = processor.process_signal("Recommend BUY based on analysis")
|
||||||
assert result == "BUY"
|
assert result == "BUY"
|
||||||
processor.process_signal.assert_called_once()
|
processor.process_signal.assert_called_once()
|
||||||
|
|
@ -34,7 +36,7 @@ class TestSignalProcessor:
|
||||||
"""Test processing SELL signal."""
|
"""Test processing SELL signal."""
|
||||||
processor = Mock()
|
processor = Mock()
|
||||||
processor.process_signal = Mock(return_value="SELL")
|
processor.process_signal = Mock(return_value="SELL")
|
||||||
|
|
||||||
result = processor.process_signal("Recommend SELL based on analysis")
|
result = processor.process_signal("Recommend SELL based on analysis")
|
||||||
assert result == "SELL"
|
assert result == "SELL"
|
||||||
|
|
||||||
|
|
@ -42,7 +44,7 @@ class TestSignalProcessor:
|
||||||
"""Test processing HOLD signal."""
|
"""Test processing HOLD signal."""
|
||||||
processor = Mock()
|
processor = Mock()
|
||||||
processor.process_signal = Mock(return_value="HOLD")
|
processor.process_signal = Mock(return_value="HOLD")
|
||||||
|
|
||||||
result = processor.process_signal("Recommend HOLD based on analysis")
|
result = processor.process_signal("Recommend HOLD based on analysis")
|
||||||
assert result == "HOLD"
|
assert result == "HOLD"
|
||||||
|
|
||||||
|
|
@ -50,7 +52,7 @@ class TestSignalProcessor:
|
||||||
"""Test processing signal with confidence score."""
|
"""Test processing signal with confidence score."""
|
||||||
processor = Mock()
|
processor = Mock()
|
||||||
processor.process_signal = Mock(return_value="BUY")
|
processor.process_signal = Mock(return_value="BUY")
|
||||||
|
|
||||||
signal = "BUY with confidence 0.85"
|
signal = "BUY with confidence 0.85"
|
||||||
result = processor.process_signal(signal)
|
result = processor.process_signal(signal)
|
||||||
assert result == "BUY"
|
assert result == "BUY"
|
||||||
|
|
@ -59,22 +61,22 @@ class TestSignalProcessor:
|
||||||
"""Test processing invalid signal."""
|
"""Test processing invalid signal."""
|
||||||
processor = Mock()
|
processor = Mock()
|
||||||
processor.process_signal = Mock(return_value="HOLD") # Default to HOLD
|
processor.process_signal = Mock(return_value="HOLD") # Default to HOLD
|
||||||
|
|
||||||
result = processor.process_signal("Invalid signal text")
|
result = processor.process_signal("Invalid signal text")
|
||||||
assert result == "HOLD"
|
assert result == "HOLD"
|
||||||
|
|
||||||
def test_extract_decision_from_text(self):
|
def test_extract_decision_from_text(self):
|
||||||
"""Test extracting decision from complex text."""
|
"""Test extracting decision from complex text."""
|
||||||
processor = Mock()
|
processor = Mock()
|
||||||
|
|
||||||
test_cases = [
|
test_cases = [
|
||||||
("After analysis, I recommend BUY", "BUY"),
|
("After analysis, I recommend BUY", "BUY"),
|
||||||
("The decision is to SELL immediately", "SELL"),
|
("The decision is to SELL immediately", "SELL"),
|
||||||
("Best action: HOLD position", "HOLD"),
|
("Best action: HOLD position", "HOLD"),
|
||||||
("FINAL TRANSACTION PROPOSAL: **BUY**", "BUY"),
|
("FINAL TRANSACTION PROPOSAL: **BUY**", "BUY"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for text, expected in test_cases:
|
for text, expected in test_cases:
|
||||||
processor.process_signal = Mock(return_value=expected)
|
processor.process_signal = Mock(return_value=expected)
|
||||||
result = processor.process_signal(text)
|
result = processor.process_signal(text)
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from unittest.mock import Mock, mock_open, patch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
from .mock_toolkit_fix import patch_toolkit_in_test
|
||||||
|
|
||||||
|
|
||||||
class TestTradingAgentsGraph:
|
class TestTradingAgentsGraph:
|
||||||
|
|
@ -26,9 +27,8 @@ class TestTradingAgentsGraph:
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = patch_toolkit_in_test(mock_toolkit)
|
||||||
mock_toolkit_instance.config = sample_config
|
mock_toolkit_instance.config = sample_config
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
|
|
@ -56,7 +56,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -79,7 +79,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_anthropic.return_value = mock_llm
|
mock_chat_anthropic.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -102,7 +102,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_google.return_value = mock_llm
|
mock_chat_google.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -121,7 +121,7 @@ class TestTradingAgentsGraph:
|
||||||
sample_config["project_dir"] = temp_data_dir
|
sample_config["project_dir"] = temp_data_dir
|
||||||
sample_config["llm_provider"] = "unsupported"
|
sample_config["llm_provider"] = "unsupported"
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Unsupported LLM provider"):
|
with pytest.raises(ValueError, match="Unsupported LLM provider"):
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -147,7 +147,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_toolkit_instance.get_YFin_data = Mock()
|
mock_toolkit_instance.get_YFin_data = Mock()
|
||||||
mock_toolkit_instance.get_stockstats_indicators_report_online = Mock()
|
mock_toolkit_instance.get_stockstats_indicators_report_online = Mock()
|
||||||
mock_toolkit_instance.get_stockstats_indicators_report = Mock()
|
mock_toolkit_instance.get_stockstats_indicators_report = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -173,7 +173,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
# Mock the graph and its invoke method
|
# Mock the graph and its invoke method
|
||||||
mock_graph = Mock()
|
mock_graph = Mock()
|
||||||
|
|
@ -240,7 +240,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
# Mock the graph stream method for debug mode
|
# Mock the graph stream method for debug mode
|
||||||
mock_graph = Mock()
|
mock_graph = Mock()
|
||||||
|
|
@ -283,7 +283,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -342,7 +342,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
|
|
@ -388,7 +388,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -425,7 +425,7 @@ class TestTradingAgentsGraph:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -447,7 +447,7 @@ class TestTradingAgentsGraphErrorHandling:
|
||||||
"""Test handling of invalid configuration."""
|
"""Test handling of invalid configuration."""
|
||||||
invalid_config = {"invalid_key": "invalid_value"}
|
invalid_config = {"invalid_key": "invalid_value"}
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
# This should still work as the class should use defaults for missing keys
|
# This should still work as the class should use defaults for missing keys
|
||||||
with patch("tradingagents.dataflows.config.set_config"):
|
with patch("tradingagents.dataflows.config.set_config"):
|
||||||
|
|
@ -469,7 +469,7 @@ class TestTradingAgentsGraphErrorHandling:
|
||||||
mock_llm = Mock()
|
mock_llm = Mock()
|
||||||
mock_chat_openai.return_value = mock_llm
|
mock_chat_openai.return_value = mock_llm
|
||||||
mock_toolkit_instance = Mock()
|
mock_toolkit_instance = Mock()
|
||||||
mock_toolkit.return_value = mock_toolkit_instance
|
patch_toolkit_in_test(mock_toolkit)
|
||||||
|
|
||||||
# Should handle directory creation gracefully or raise appropriate error
|
# Should handle directory creation gracefully or raise appropriate error
|
||||||
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue