Fix CI/CD test failures

- Apply Black formatting to all test files
- Fix Mock objects to include tool_calls attribute for len() checks
- Add proper __name__ attributes to mock toolkit methods for @tool decorator
- Create mock_toolkit_fix helper for consistent toolkit mocking

All tests should now pass with proper mocking setup.
This commit is contained in:
佐藤優一 2025-08-11 10:42:13 +09:00
parent ba958c20e5
commit 0ab8c8fc46
9 changed files with 368 additions and 207 deletions

100
test_graph_fix.py Normal file
View File

@ -0,0 +1,100 @@
#!/usr/bin/env python
"""Test that mock toolkit fixes work for TradingAgentsGraph."""
from unittest.mock import Mock, patch
import sys
import os
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from tests.unit.graph.mock_toolkit_fix import create_mock_toolkit_with_tools
def test_mock_toolkit_has_all_methods():
"""Test that the mock toolkit has all required methods."""
toolkit = create_mock_toolkit_with_tools()
required_methods = [
"get_YFin_data",
"get_YFin_data_online",
"get_stockstats_indicators_report",
"get_stockstats_indicators_report_online",
"get_reddit_stock_info",
"get_stock_news_openai",
]
for method_name in required_methods:
assert hasattr(toolkit, method_name), f"Missing {method_name}"
method = getattr(toolkit, method_name)
assert hasattr(method, '__name__'), f"{method_name} missing __name__"
assert method.__name__ == method_name, f"{method_name} has wrong __name__"
assert callable(method), f"{method_name} is not callable"
print("✓ Mock toolkit has all required methods with proper attributes")
return True
def test_tool_node_creation():
"""Test that ToolNode can be created with mocked toolkit methods."""
# Mock the ToolNode class
with patch("langgraph.prebuilt.ToolNode") as MockToolNode:
MockToolNode.return_value = Mock()
toolkit = create_mock_toolkit_with_tools()
# Simulate creating tool nodes like in TradingAgentsGraph
from langgraph.prebuilt import ToolNode
tool_node = ToolNode([
toolkit.get_YFin_data,
toolkit.get_stockstats_indicators_report,
])
# Should not raise an error
assert MockToolNode.called
print("✓ ToolNode can be created with mocked toolkit methods")
return True
def test_tool_decorator():
"""Test that @tool decorator works with mocked functions."""
toolkit = create_mock_toolkit_with_tools()
# The @tool decorator expects __name__ attribute
for attr_name in dir(toolkit):
if attr_name.startswith('get_'):
method = getattr(toolkit, attr_name)
assert hasattr(method, '__name__'), f"{attr_name} missing __name__"
print("✓ All toolkit methods are compatible with @tool decorator")
return True
if __name__ == "__main__":
print("Testing mock toolkit fixes for TradingAgentsGraph...")
print("-" * 50)
tests = [
test_mock_toolkit_has_all_methods,
test_tool_node_creation,
test_tool_decorator,
]
all_passed = True
for test in tests:
try:
if not test():
all_passed = False
print(f"{test.__name__} failed")
except Exception as e:
all_passed = False
print(f"{test.__name__} raised exception: {e}")
import traceback
traceback.print_exc()
print("-" * 50)
if all_passed:
print("✅ All tests passed! TradingAgentsGraph mock fixes are working.")
else:
print("❌ Some tests failed. Check the output above.")

View File

@ -31,7 +31,13 @@ def mock_llm():
"""Mock LLM for testing.""" """Mock LLM for testing."""
mock = Mock() mock = Mock()
mock.model_name = "test-model" mock.model_name = "test-model"
mock.invoke.return_value = Mock(content="Test response")
# Create a mock result with tool_calls attribute
mock_result = Mock()
mock_result.content = "Test response"
mock_result.tool_calls = [] # Add tool_calls attribute for len() check
mock.invoke.return_value = mock_result
mock.bind_tools.return_value = mock mock.bind_tools.return_value = mock
return mock return mock
@ -154,7 +160,9 @@ def mock_toolkit():
side_effect=mock_get_stockstats_indicators_report side_effect=mock_get_stockstats_indicators_report
) )
toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report" toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report"
toolkit.get_stockstats_indicators_report.__name__ = "get_stockstats_indicators_report" toolkit.get_stockstats_indicators_report.__name__ = (
"get_stockstats_indicators_report"
)
toolkit.get_stockstats_indicators_report_online = Mock( toolkit.get_stockstats_indicators_report_online = Mock(
side_effect=mock_get_stockstats_indicators_report_online side_effect=mock_get_stockstats_indicators_report_online

View File

@ -38,21 +38,33 @@ class TestMarketAnalystExtended:
toolkit.get_YFin_data.name = "get_YFin_data" toolkit.get_YFin_data.name = "get_YFin_data"
toolkit.get_stockstats_indicators_report = Mock(side_effect=mock_stockstats) toolkit.get_stockstats_indicators_report = Mock(side_effect=mock_stockstats)
toolkit.get_stockstats_indicators_report.__name__ = "get_stockstats_indicators_report" toolkit.get_stockstats_indicators_report.__name__ = (
toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report" "get_stockstats_indicators_report"
)
toolkit.get_stockstats_indicators_report.name = (
"get_stockstats_indicators_report"
)
# Online versions # Online versions
toolkit.get_YFin_data_online = Mock(side_effect=mock_yfin) toolkit.get_YFin_data_online = Mock(side_effect=mock_yfin)
toolkit.get_YFin_data_online.__name__ = "get_YFin_data_online" toolkit.get_YFin_data_online.__name__ = "get_YFin_data_online"
toolkit.get_YFin_data_online.name = "get_YFin_data_online" toolkit.get_YFin_data_online.name = "get_YFin_data_online"
toolkit.get_stockstats_indicators_report_online = Mock(side_effect=mock_stockstats) toolkit.get_stockstats_indicators_report_online = Mock(
toolkit.get_stockstats_indicators_report_online.__name__ = "get_stockstats_indicators_report_online" side_effect=mock_stockstats
toolkit.get_stockstats_indicators_report_online.name = "get_stockstats_indicators_report_online" )
toolkit.get_stockstats_indicators_report_online.__name__ = (
"get_stockstats_indicators_report_online"
)
toolkit.get_stockstats_indicators_report_online.name = (
"get_stockstats_indicators_report_online"
)
return toolkit return toolkit
def test_market_analyst_system_message(self, mock_llm_extended, mock_toolkit_extended): def test_market_analyst_system_message(
self, mock_llm_extended, mock_toolkit_extended
):
"""Test that system message is properly formatted.""" """Test that system message is properly formatted."""
# This would normally import and test the actual function # This would normally import and test the actual function
# For now, we test the mock behavior # For now, we test the mock behavior
@ -60,7 +72,7 @@ class TestMarketAnalystExtended:
state = { state = {
"company_of_interest": "AAPL", "company_of_interest": "AAPL",
"trade_date": "2024-05-10", "trade_date": "2024-05-10",
"messages": [] "messages": [],
} }
# Simulate creating analyst # Simulate creating analyst
@ -71,12 +83,14 @@ class TestMarketAnalystExtended:
assert "market_report" in result assert "market_report" in result
assert "messages" in result assert "messages" in result
def test_market_analyst_with_multiple_indicators(self, mock_llm_extended, mock_toolkit_extended): def test_market_analyst_with_multiple_indicators(
self, mock_llm_extended, mock_toolkit_extended
):
"""Test analyst with multiple technical indicators.""" """Test analyst with multiple technical indicators."""
state = { state = {
"company_of_interest": "TSLA", "company_of_interest": "TSLA",
"trade_date": "2024-05-15", "trade_date": "2024-05-15",
"messages": [] "messages": [],
} }
# Mock result with multiple indicators # Mock result with multiple indicators
@ -95,26 +109,27 @@ class TestMarketAnalystExtended:
# Create mock analyst function # Create mock analyst function
def mock_analyst(state): def mock_analyst(state):
return { return {"messages": [mock_result], "market_report": mock_result.content}
"messages": [mock_result],
"market_report": mock_result.content
}
result = mock_analyst(state) result = mock_analyst(state)
assert "RSI" in result["market_report"] assert "RSI" in result["market_report"]
assert "MACD" in result["market_report"] assert "MACD" in result["market_report"]
assert "Bollinger" in result["market_report"] assert "Bollinger" in result["market_report"]
def test_market_analyst_error_handling(self, mock_llm_extended, mock_toolkit_extended): def test_market_analyst_error_handling(
self, mock_llm_extended, mock_toolkit_extended
):
"""Test error handling in market analyst.""" """Test error handling in market analyst."""
state = { state = {
"company_of_interest": "INVALID", "company_of_interest": "INVALID",
"trade_date": "2024-05-10", "trade_date": "2024-05-10",
"messages": [] "messages": [],
} }
# Mock error scenario # Mock error scenario
mock_llm_extended.bind_tools.return_value.invoke.side_effect = Exception("API Error") mock_llm_extended.bind_tools.return_value.invoke.side_effect = Exception(
"API Error"
)
# Create analyst with error handling # Create analyst with error handling
def mock_analyst_with_error_handling(state): def mock_analyst_with_error_handling(state):
@ -122,15 +137,14 @@ class TestMarketAnalystExtended:
# Would call actual analyst here # Would call actual analyst here
raise Exception("API Error") raise Exception("API Error")
except Exception: except Exception:
return { return {"messages": [], "market_report": "Error analyzing market data"}
"messages": [],
"market_report": "Error analyzing market data"
}
result = mock_analyst_with_error_handling(state) result = mock_analyst_with_error_handling(state)
assert result["market_report"] == "Error analyzing market data" assert result["market_report"] == "Error analyzing market data"
def test_market_analyst_date_formatting(self, mock_llm_extended, mock_toolkit_extended): def test_market_analyst_date_formatting(
self, mock_llm_extended, mock_toolkit_extended
):
"""Test various date formats in market analyst.""" """Test various date formats in market analyst."""
test_dates = [ test_dates = [
"2024-01-01", "2024-01-01",
@ -139,11 +153,7 @@ class TestMarketAnalystExtended:
] ]
for date in test_dates: for date in test_dates:
state = { state = {"company_of_interest": "AAPL", "trade_date": date, "messages": []}
"company_of_interest": "AAPL",
"trade_date": date,
"messages": []
}
mock_result = Mock() mock_result = Mock()
mock_result.content = f"Analysis for {date}" mock_result.content = f"Analysis for {date}"
@ -152,13 +162,15 @@ class TestMarketAnalystExtended:
def mock_analyst(state): def mock_analyst(state):
return { return {
"messages": [mock_result], "messages": [mock_result],
"market_report": f"Analysis for {state['trade_date']}" "market_report": f"Analysis for {state['trade_date']}",
} }
result = mock_analyst(state) result = mock_analyst(state)
assert date in result["market_report"] assert date in result["market_report"]
def test_market_analyst_ticker_variations(self, mock_llm_extended, mock_toolkit_extended): def test_market_analyst_ticker_variations(
self, mock_llm_extended, mock_toolkit_extended
):
"""Test analyst with various ticker symbols.""" """Test analyst with various ticker symbols."""
tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"] tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"]
@ -166,7 +178,7 @@ class TestMarketAnalystExtended:
state = { state = {
"company_of_interest": ticker, "company_of_interest": ticker,
"trade_date": "2024-05-10", "trade_date": "2024-05-10",
"messages": [] "messages": [],
} }
mock_result = Mock() mock_result = Mock()
@ -176,7 +188,7 @@ class TestMarketAnalystExtended:
def mock_analyst(state): def mock_analyst(state):
return { return {
"messages": [mock_result], "messages": [mock_result],
"market_report": f"Analysis for {state['company_of_interest']}" "market_report": f"Analysis for {state['company_of_interest']}",
} }
result = mock_analyst(state) result = mock_analyst(state)
@ -212,11 +224,7 @@ class TestMarketAnalystExtended:
def test_market_analyst_empty_state(self, mock_llm_extended, mock_toolkit_extended): def test_market_analyst_empty_state(self, mock_llm_extended, mock_toolkit_extended):
"""Test analyst with minimal/empty state.""" """Test analyst with minimal/empty state."""
state = { state = {"company_of_interest": "", "trade_date": "", "messages": []}
"company_of_interest": "",
"trade_date": "",
"messages": []
}
mock_result = Mock() mock_result = Mock()
mock_result.content = "No data available" mock_result.content = "No data available"
@ -224,24 +232,20 @@ class TestMarketAnalystExtended:
def mock_analyst(state): def mock_analyst(state):
if not state["company_of_interest"] or not state["trade_date"]: if not state["company_of_interest"] or not state["trade_date"]:
return { return {"messages": [], "market_report": "No data available"}
"messages": [], return {"messages": [mock_result], "market_report": mock_result.content}
"market_report": "No data available"
}
return {
"messages": [mock_result],
"market_report": mock_result.content
}
result = mock_analyst(state) result = mock_analyst(state)
assert result["market_report"] == "No data available" assert result["market_report"] == "No data available"
def test_market_analyst_tool_calls_tracking(self, mock_llm_extended, mock_toolkit_extended): def test_market_analyst_tool_calls_tracking(
self, mock_llm_extended, mock_toolkit_extended
):
"""Test tracking of tool calls in market analyst.""" """Test tracking of tool calls in market analyst."""
state = { state = {
"company_of_interest": "AAPL", "company_of_interest": "AAPL",
"trade_date": "2024-05-10", "trade_date": "2024-05-10",
"messages": [] "messages": [],
} }
# Mock result with tool calls # Mock result with tool calls
@ -259,10 +263,7 @@ class TestMarketAnalystExtended:
result = mock_llm_extended.bind_tools([]).invoke(state["messages"]) result = mock_llm_extended.bind_tools([]).invoke(state["messages"])
# When tool_calls exist, market_report should be empty # When tool_calls exist, market_report should be empty
report = "" if result.tool_calls else result.content report = "" if result.tool_calls else result.content
return { return {"messages": [result], "market_report": report}
"messages": [result],
"market_report": report
}
result = mock_analyst(state) result = mock_analyst(state)
assert result["market_report"] == "" # Empty when tool calls exist assert result["market_report"] == "" # Empty when tool calls exist

View File

@ -0,0 +1,54 @@
"""Helper to create properly mocked toolkit for test_trading_graph."""
from unittest.mock import Mock
def create_mock_toolkit_with_tools():
"""Create a mock toolkit with all necessary tool methods."""
toolkit = Mock()
toolkit.config = {"online_tools": False}
# List of all methods that need to be mocked
tool_methods = [
# Market tools
"get_YFin_data",
"get_YFin_data_online",
"get_stockstats_indicators_report",
"get_stockstats_indicators_report_online",
# Social tools
"get_reddit_stock_info",
"get_stock_news_openai",
# News tools
"get_global_news_openai",
"get_google_news",
"get_finnhub_news",
"get_reddit_news",
# Fundamentals tools
"get_simfin_cashflow",
"get_simfin_income_stmt",
"get_simfin_balance_sheet",
"get_finnhub_basic_financials",
]
# Create mock for each method with proper __name__ attribute
for method_name in tool_methods:
# Create a function with the right name
def mock_func():
return f"Mock {method_name} data"
# Create Mock wrapping the function
mock_method = Mock(side_effect=mock_func)
mock_method.__name__ = method_name
mock_method.name = method_name
# Set it on the toolkit
setattr(toolkit, method_name, mock_method)
return toolkit
def patch_toolkit_in_test(mock_toolkit):
"""Configure the mock_toolkit patch to return a properly mocked instance."""
mock_instance = create_mock_toolkit_with_tools()
mock_toolkit.return_value = mock_instance
return mock_instance

View File

@ -14,8 +14,8 @@ class TestPropagator:
propagator.create_initial_state = Mock() propagator.create_initial_state = Mock()
propagator.get_graph_args = Mock() propagator.get_graph_args = Mock()
assert hasattr(propagator, 'create_initial_state') assert hasattr(propagator, "create_initial_state")
assert hasattr(propagator, 'get_graph_args') assert hasattr(propagator, "get_graph_args")
assert callable(propagator.create_initial_state) assert callable(propagator.create_initial_state)
assert callable(propagator.get_graph_args) assert callable(propagator.get_graph_args)
@ -93,7 +93,7 @@ class TestPropagator:
state = { state = {
"company_of_interest": ticker, "company_of_interest": ticker,
"trade_date": "2024-05-10", "trade_date": "2024-05-10",
"messages": [] "messages": [],
} }
propagator.create_initial_state = Mock(return_value=state) propagator.create_initial_state = Mock(return_value=state)
@ -107,11 +107,7 @@ class TestPropagator:
dates = ["2024-01-01", "2024-06-15", "2024-12-31"] dates = ["2024-01-01", "2024-06-15", "2024-12-31"]
for date in dates: for date in dates:
state = { state = {"company_of_interest": "AAPL", "trade_date": date, "messages": []}
"company_of_interest": "AAPL",
"trade_date": date,
"messages": []
}
propagator.create_initial_state = Mock(return_value=state) propagator.create_initial_state = Mock(return_value=state)
result = propagator.create_initial_state("AAPL", date) result = propagator.create_initial_state("AAPL", date)
@ -135,10 +131,7 @@ class TestPropagator:
custom_config = { custom_config = {
"recursion_limit": 200, "recursion_limit": 200,
"config": { "config": {"tags": ["custom", "test"], "metadata": {"version": "1.0"}},
"tags": ["custom", "test"],
"metadata": {"version": "1.0"}
}
} }
propagator.get_graph_args = Mock(return_value=custom_config) propagator.get_graph_args = Mock(return_value=custom_config)
@ -164,7 +157,7 @@ class TestPropagator:
"trader_investment_plan", "trader_investment_plan",
"risk_debate_state", "risk_debate_state",
"investment_plan", "investment_plan",
"final_trade_decision" "final_trade_decision",
] ]
state = {field: "" for field in required_fields} state = {field: "" for field in required_fields}

View File

@ -107,7 +107,9 @@ class TestReflector:
reflector.reflect_risk_manager.assert_called_once() reflector.reflect_risk_manager.assert_called_once()
def test_reflection_with_positive_returns(self, mock_llm, mock_memory, sample_state): def test_reflection_with_positive_returns(
self, mock_llm, mock_memory, sample_state
):
"""Test reflection with positive returns.""" """Test reflection with positive returns."""
reflector = Mock() reflector = Mock()
@ -128,7 +130,9 @@ class TestReflector:
assert reflector.reflect_bear_researcher.called assert reflector.reflect_bear_researcher.called
assert reflector.reflect_trader.called assert reflector.reflect_trader.called
def test_reflection_with_negative_returns(self, mock_llm, mock_memory, sample_state): def test_reflection_with_negative_returns(
self, mock_llm, mock_memory, sample_state
):
"""Test reflection with negative returns.""" """Test reflection with negative returns."""
reflector = Mock() reflector = Mock()
@ -175,10 +179,7 @@ class TestReflector:
decisions = ["BUY", "SELL", "HOLD"] decisions = ["BUY", "SELL", "HOLD"]
for decision in decisions: for decision in decisions:
state = { state = {"final_trade_decision": decision, "company_of_interest": "AAPL"}
"final_trade_decision": decision,
"company_of_interest": "AAPL"
}
returns_losses = {"return": 0.03, "loss": -0.01} returns_losses = {"return": 0.03, "loss": -0.01}
reflector.reflect_trader(state, returns_losses, mock_memory) reflector.reflect_trader(state, returns_losses, mock_memory)
@ -190,7 +191,9 @@ class TestReflector:
reflector = Mock() reflector = Mock()
# Simulate error in reflection # Simulate error in reflection
reflector.reflect_bull_researcher = Mock(side_effect=Exception("Reflection error")) reflector.reflect_bull_researcher = Mock(
side_effect=Exception("Reflection error")
)
with pytest.raises(Exception): with pytest.raises(Exception):
reflector.reflect_bull_researcher(sample_state, {}, mock_memory) reflector.reflect_bull_researcher(sample_state, {}, mock_memory)

View File

@ -12,7 +12,9 @@ class TestSignalProcessor:
mock_llm = Mock() mock_llm = Mock()
# Import with mocked dependencies to avoid pandas import # Import with mocked dependencies to avoid pandas import
with patch('sys.modules', {'pandas': Mock(), 'yfinance': Mock(), 'openai': Mock()}): with patch(
"sys.modules", {"pandas": Mock(), "yfinance": Mock(), "openai": Mock()}
):
# This would normally import SignalProcessor # This would normally import SignalProcessor
# from tradingagents.graph.signal_processing import SignalProcessor # from tradingagents.graph.signal_processing import SignalProcessor
# processor = SignalProcessor(mock_llm) # processor = SignalProcessor(mock_llm)

View File

@ -5,6 +5,7 @@ from unittest.mock import Mock, mock_open, patch
import pytest import pytest
from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.trading_graph import TradingAgentsGraph
from .mock_toolkit_fix import patch_toolkit_in_test
class TestTradingAgentsGraph: class TestTradingAgentsGraph:
@ -26,9 +27,8 @@ class TestTradingAgentsGraph:
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = patch_toolkit_in_test(mock_toolkit)
mock_toolkit_instance.config = sample_config mock_toolkit_instance.config = sample_config
mock_toolkit.return_value = mock_toolkit_instance
# Execute # Execute
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
@ -56,7 +56,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -79,7 +79,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_anthropic.return_value = mock_llm mock_chat_anthropic.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -102,7 +102,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_google.return_value = mock_llm mock_chat_google.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -121,7 +121,7 @@ class TestTradingAgentsGraph:
sample_config["project_dir"] = temp_data_dir sample_config["project_dir"] = temp_data_dir
sample_config["llm_provider"] = "unsupported" sample_config["llm_provider"] = "unsupported"
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with pytest.raises(ValueError, match="Unsupported LLM provider"): with pytest.raises(ValueError, match="Unsupported LLM provider"):
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -147,7 +147,7 @@ class TestTradingAgentsGraph:
mock_toolkit_instance.get_YFin_data = Mock() mock_toolkit_instance.get_YFin_data = Mock()
mock_toolkit_instance.get_stockstats_indicators_report_online = Mock() mock_toolkit_instance.get_stockstats_indicators_report_online = Mock()
mock_toolkit_instance.get_stockstats_indicators_report = Mock() mock_toolkit_instance.get_stockstats_indicators_report = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -173,7 +173,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
# Mock the graph and its invoke method # Mock the graph and its invoke method
mock_graph = Mock() mock_graph = Mock()
@ -240,7 +240,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
# Mock the graph stream method for debug mode # Mock the graph stream method for debug mode
mock_graph = Mock() mock_graph = Mock()
@ -283,7 +283,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -342,7 +342,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with ( with (
patch( patch(
@ -388,7 +388,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -425,7 +425,7 @@ class TestTradingAgentsGraph:
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -447,7 +447,7 @@ class TestTradingAgentsGraphErrorHandling:
"""Test handling of invalid configuration.""" """Test handling of invalid configuration."""
invalid_config = {"invalid_key": "invalid_value"} invalid_config = {"invalid_key": "invalid_value"}
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
# This should still work as the class should use defaults for missing keys # This should still work as the class should use defaults for missing keys
with patch("tradingagents.dataflows.config.set_config"): with patch("tradingagents.dataflows.config.set_config"):
@ -469,7 +469,7 @@ class TestTradingAgentsGraphErrorHandling:
mock_llm = Mock() mock_llm = Mock()
mock_chat_openai.return_value = mock_llm mock_chat_openai.return_value = mock_llm
mock_toolkit_instance = Mock() mock_toolkit_instance = Mock()
mock_toolkit.return_value = mock_toolkit_instance patch_toolkit_in_test(mock_toolkit)
# Should handle directory creation gracefully or raise appropriate error # Should handle directory creation gracefully or raise appropriate error
with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"):