272 lines
9.3 KiB
Python
272 lines
9.3 KiB
Python
"""Extended unit tests for market analyst to improve coverage."""
|
|
|
|
from unittest.mock import Mock
|
|
import pytest
|
|
|
|
|
|
class TestMarketAnalystExtended:
|
|
"""Extended test suite for market analyst functionality."""
|
|
|
|
@pytest.fixture
|
|
def mock_llm_extended(self):
|
|
"""Extended mock LLM with more functionality."""
|
|
mock = Mock()
|
|
mock.model_name = "test-model"
|
|
|
|
# Create a mock chain
|
|
mock_chain = Mock()
|
|
mock_chain.invoke = Mock()
|
|
mock.bind_tools = Mock(return_value=mock_chain)
|
|
|
|
return mock
|
|
|
|
@pytest.fixture
|
|
def mock_toolkit_extended(self):
|
|
"""Extended mock toolkit with all methods."""
|
|
toolkit = Mock()
|
|
toolkit.config = {"online_tools": False}
|
|
|
|
# Create mock functions with proper attributes
|
|
def mock_yfin():
|
|
return "YFin data"
|
|
|
|
def mock_stockstats():
|
|
return "Stockstats data"
|
|
|
|
toolkit.get_YFin_data = Mock(side_effect=mock_yfin)
|
|
toolkit.get_YFin_data.__name__ = "get_YFin_data"
|
|
toolkit.get_YFin_data.name = "get_YFin_data"
|
|
|
|
toolkit.get_stockstats_indicators_report = Mock(side_effect=mock_stockstats)
|
|
toolkit.get_stockstats_indicators_report.__name__ = (
|
|
"get_stockstats_indicators_report"
|
|
)
|
|
toolkit.get_stockstats_indicators_report.name = (
|
|
"get_stockstats_indicators_report"
|
|
)
|
|
|
|
# Online versions
|
|
toolkit.get_YFin_data_online = Mock(side_effect=mock_yfin)
|
|
toolkit.get_YFin_data_online.__name__ = "get_YFin_data_online"
|
|
toolkit.get_YFin_data_online.name = "get_YFin_data_online"
|
|
|
|
toolkit.get_stockstats_indicators_report_online = Mock(
|
|
side_effect=mock_stockstats
|
|
)
|
|
toolkit.get_stockstats_indicators_report_online.__name__ = (
|
|
"get_stockstats_indicators_report_online"
|
|
)
|
|
toolkit.get_stockstats_indicators_report_online.name = (
|
|
"get_stockstats_indicators_report_online"
|
|
)
|
|
|
|
return toolkit
|
|
|
|
def test_market_analyst_system_message(
|
|
self, mock_llm_extended, mock_toolkit_extended
|
|
):
|
|
"""Test that system message is properly formatted."""
|
|
# This would normally import and test the actual function
|
|
# For now, we test the mock behavior
|
|
|
|
state = {
|
|
"company_of_interest": "AAPL",
|
|
"trade_date": "2024-05-10",
|
|
"messages": [],
|
|
}
|
|
|
|
# Simulate creating analyst
|
|
mock_analyst = Mock()
|
|
mock_analyst.return_value = {"messages": [], "market_report": "Test report"}
|
|
|
|
result = mock_analyst(state)
|
|
assert "market_report" in result
|
|
assert "messages" in result
|
|
|
|
def test_market_analyst_with_multiple_indicators(
|
|
self, mock_llm_extended, mock_toolkit_extended
|
|
):
|
|
"""Test analyst with multiple technical indicators."""
|
|
state = {
|
|
"company_of_interest": "TSLA",
|
|
"trade_date": "2024-05-15",
|
|
"messages": [],
|
|
}
|
|
|
|
# Mock result with multiple indicators
|
|
mock_result = Mock()
|
|
mock_result.content = """
|
|
Analysis with multiple indicators:
|
|
- RSI: 65 (neutral)
|
|
- MACD: Bullish crossover
|
|
- Bollinger Bands: Price near upper band
|
|
- 50 SMA: Upward trend
|
|
- Volume: Above average
|
|
"""
|
|
mock_result.tool_calls = []
|
|
|
|
mock_llm_extended.bind_tools.return_value.invoke.return_value = mock_result
|
|
|
|
# Create mock analyst function
|
|
def mock_analyst(state):
|
|
return {"messages": [mock_result], "market_report": mock_result.content}
|
|
|
|
result = mock_analyst(state)
|
|
assert "RSI" in result["market_report"]
|
|
assert "MACD" in result["market_report"]
|
|
assert "Bollinger" in result["market_report"]
|
|
|
|
def test_market_analyst_error_handling(
|
|
self, mock_llm_extended, mock_toolkit_extended
|
|
):
|
|
"""Test error handling in market analyst."""
|
|
state = {
|
|
"company_of_interest": "INVALID",
|
|
"trade_date": "2024-05-10",
|
|
"messages": [],
|
|
}
|
|
|
|
# Mock error scenario
|
|
mock_llm_extended.bind_tools.return_value.invoke.side_effect = Exception(
|
|
"API Error"
|
|
)
|
|
|
|
# Create analyst with error handling
|
|
def mock_analyst_with_error_handling(state):
|
|
try:
|
|
# Would call actual analyst here
|
|
raise Exception("API Error")
|
|
except Exception:
|
|
return {"messages": [], "market_report": "Error analyzing market data"}
|
|
|
|
result = mock_analyst_with_error_handling(state)
|
|
assert result["market_report"] == "Error analyzing market data"
|
|
|
|
def test_market_analyst_date_formatting(
|
|
self, mock_llm_extended, mock_toolkit_extended
|
|
):
|
|
"""Test various date formats in market analyst."""
|
|
test_dates = [
|
|
"2024-01-01",
|
|
"2024-12-31",
|
|
"2024-05-15",
|
|
]
|
|
|
|
for date in test_dates:
|
|
state = {"company_of_interest": "AAPL", "trade_date": date, "messages": []}
|
|
|
|
mock_result = Mock()
|
|
mock_result.content = f"Analysis for {date}"
|
|
mock_result.tool_calls = []
|
|
|
|
def mock_analyst(state):
|
|
return {
|
|
"messages": [mock_result],
|
|
"market_report": f"Analysis for {state['trade_date']}",
|
|
}
|
|
|
|
result = mock_analyst(state)
|
|
assert date in result["market_report"]
|
|
|
|
def test_market_analyst_ticker_variations(
|
|
self, mock_llm_extended, mock_toolkit_extended
|
|
):
|
|
"""Test analyst with various ticker symbols."""
|
|
tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA"]
|
|
|
|
for ticker in tickers:
|
|
state = {
|
|
"company_of_interest": ticker,
|
|
"trade_date": "2024-05-10",
|
|
"messages": [],
|
|
}
|
|
|
|
mock_result = Mock()
|
|
mock_result.content = f"Analysis for {ticker}"
|
|
mock_result.tool_calls = []
|
|
|
|
def mock_analyst(state):
|
|
return {
|
|
"messages": [mock_result],
|
|
"market_report": f"Analysis for {state['company_of_interest']}",
|
|
}
|
|
|
|
result = mock_analyst(state)
|
|
assert ticker in result["market_report"]
|
|
|
|
def test_market_analyst_online_vs_offline(self, mock_llm_extended):
|
|
"""Test analyst behavior with online vs offline tools."""
|
|
# Test offline configuration
|
|
toolkit_offline = Mock()
|
|
toolkit_offline.config = {"online_tools": False}
|
|
|
|
def mock_offline():
|
|
return "Offline data"
|
|
|
|
toolkit_offline.get_YFin_data = Mock(side_effect=mock_offline)
|
|
toolkit_offline.get_YFin_data.__name__ = "get_YFin_data"
|
|
|
|
# Test online configuration
|
|
toolkit_online = Mock()
|
|
toolkit_online.config = {"online_tools": True}
|
|
|
|
def mock_online():
|
|
return "Online data"
|
|
|
|
toolkit_online.get_YFin_data_online = Mock(side_effect=mock_online)
|
|
toolkit_online.get_YFin_data_online.__name__ = "get_YFin_data_online"
|
|
|
|
# Both should work correctly
|
|
assert toolkit_offline.config["online_tools"] is False
|
|
assert toolkit_online.config["online_tools"] is True
|
|
assert toolkit_offline.get_YFin_data() == "Offline data"
|
|
assert toolkit_online.get_YFin_data_online() == "Online data"
|
|
|
|
def test_market_analyst_empty_state(self, mock_llm_extended, mock_toolkit_extended):
|
|
"""Test analyst with minimal/empty state."""
|
|
state = {"company_of_interest": "", "trade_date": "", "messages": []}
|
|
|
|
mock_result = Mock()
|
|
mock_result.content = "No data available"
|
|
mock_result.tool_calls = []
|
|
|
|
def mock_analyst(state):
|
|
if not state["company_of_interest"] or not state["trade_date"]:
|
|
return {"messages": [], "market_report": "No data available"}
|
|
return {"messages": [mock_result], "market_report": mock_result.content}
|
|
|
|
result = mock_analyst(state)
|
|
assert result["market_report"] == "No data available"
|
|
|
|
def test_market_analyst_tool_calls_tracking(
|
|
self, mock_llm_extended, mock_toolkit_extended
|
|
):
|
|
"""Test tracking of tool calls in market analyst."""
|
|
state = {
|
|
"company_of_interest": "AAPL",
|
|
"trade_date": "2024-05-10",
|
|
"messages": [],
|
|
}
|
|
|
|
# Mock result with tool calls
|
|
mock_tool_call = Mock()
|
|
mock_tool_call.function.name = "get_YFin_data"
|
|
mock_tool_call.function.arguments = '{"ticker": "AAPL"}'
|
|
|
|
mock_result = Mock()
|
|
mock_result.content = ""
|
|
mock_result.tool_calls = [mock_tool_call]
|
|
|
|
mock_llm_extended.bind_tools.return_value.invoke.return_value = mock_result
|
|
|
|
def mock_analyst(state):
|
|
result = mock_llm_extended.bind_tools([]).invoke(state["messages"])
|
|
# When tool_calls exist, market_report should be empty
|
|
report = "" if result.tool_calls else result.content
|
|
return {"messages": [result], "market_report": report}
|
|
|
|
result = mock_analyst(state)
|
|
assert result["market_report"] == "" # Empty when tool calls exist
|
|
assert len(result["messages"]) == 1
|
|
assert result["messages"][0].tool_calls == [mock_tool_call]
|