TradingAgents/tests/conftest.py

340 lines
11 KiB
Python

"""Pytest configuration and shared fixtures for TradingAgents tests."""
import os
import tempfile
from unittest.mock import Mock
import pytest
# Set test environment variables before importing DEFAULT_CONFIG
# This prevents hanging during config loading due to missing API keys
os.environ.setdefault("OPENAI_API_KEY", "test-key")
os.environ.setdefault("FINNHUB_API_KEY", "test-key")
os.environ.setdefault("REDDIT_CLIENT_ID", "test-id")
os.environ.setdefault("REDDIT_CLIENT_SECRET", "test-secret")
from tradingagents.default_config import DEFAULT_CONFIG
@pytest.fixture
def sample_config():
"""Provide a test configuration."""
config = DEFAULT_CONFIG.copy()
config.update(
{
"online_tools": False, # Use offline tools for testing
"max_debate_rounds": 1, # Limit rounds for faster tests
"llm_provider": "openai",
"deep_think_llm": "gpt-4o-mini",
"quick_think_llm": "gpt-4o-mini",
"project_dir": "/tmp/test_tradingagents",
},
)
return config
class MockResult:
"""Mock result that always has proper tool_calls attribute."""
def __init__(self, content="Test response", tool_calls=None):
self.content = content
self.tool_calls = tool_calls if tool_calls is not None else []
@pytest.fixture
def mock_llm():
"""Mock LLM for testing."""
mock = Mock()
mock.model_name = "test-model"
# Create a default mock result with proper tool_calls
default_result = MockResult()
# Simple approach: create a mock that will be returned by any chain operation
chain_result = Mock()
chain_result.return_value = default_result
# Mock the bind_tools to return a mock that handles piping
bound_mock = Mock()
bound_mock.invoke = Mock(return_value=default_result)
# Handle the pipe operation by returning a mock that also returns our result
def handle_pipe(other):
pipe_result = Mock()
pipe_result.invoke = Mock(return_value=default_result)
return pipe_result
bound_mock.__ror__ = handle_pipe
mock.bind_tools.return_value = bound_mock
# Keep direct invoke for backward compatibility
mock.invoke.return_value = default_result
return mock
@pytest.fixture
def sample_ticker():
"""Sample stock ticker for testing."""
return "AAPL"
@pytest.fixture
def sample_trade_date():
"""Sample trade date for testing."""
return "2024-05-10"
@pytest.fixture
def sample_agent_state():
"""Sample agent state for testing."""
return {
"company_of_interest": "AAPL",
"trade_date": "2024-05-10",
"messages": [],
"market_report": "",
"sentiment_report": "",
"news_report": "",
"fundamentals_report": "",
"investment_debate_state": {
"bull_history": [],
"bear_history": [],
"history": [],
"current_response": "",
"judge_decision": "",
},
"trader_investment_plan": "",
"risk_debate_state": {
"risky_history": [],
"safe_history": [],
"neutral_history": [],
"history": [],
"judge_decision": "",
},
"investment_plan": "",
"final_trade_decision": "",
}
@pytest.fixture
def temp_data_dir():
"""Create a temporary directory for test data."""
with tempfile.TemporaryDirectory() as temp_dir:
yield temp_dir
@pytest.fixture
def mock_toolkit():
"""Mock toolkit with all necessary methods."""
toolkit = Mock()
toolkit.config = {"online_tools": False}
# Create mock functions with proper __name__ attributes
def mock_get_YFin_data():
return "Mock YFin data"
def mock_get_YFin_data_online():
return "Mock YFin data online"
def mock_get_stockstats_indicators_report():
return "Mock stockstats report"
def mock_get_stockstats_indicators_report_online():
return "Mock stockstats report online"
def mock_get_reddit_stock_info():
return "Mock reddit stock info"
def mock_get_stock_news_openai():
return "Mock stock news"
def mock_get_finnhub_news():
return "Mock finnhub news"
def mock_get_reddit_news():
return "Mock reddit news"
def mock_get_global_news_openai():
return "Mock global news"
def mock_get_google_news():
return "Mock google news"
def mock_get_fundamentals_openai():
return "Mock fundamentals"
def mock_get_finnhub_company_insider_sentiment():
return "Mock insider sentiment"
def mock_get_finnhub_company_insider_transactions():
return "Mock insider transactions"
def mock_get_simfin_balance_sheet():
return "Mock balance sheet"
def mock_get_simfin_cashflow():
return "Mock cashflow"
def mock_get_simfin_income_stmt():
return "Mock income statement"
# Assign the mock functions to the toolkit
toolkit.get_YFin_data = Mock(side_effect=mock_get_YFin_data)
toolkit.get_YFin_data.name = "get_YFin_data"
toolkit.get_YFin_data.__name__ = "get_YFin_data"
toolkit.get_YFin_data_online = Mock(side_effect=mock_get_YFin_data_online)
toolkit.get_YFin_data_online.name = "get_YFin_data_online"
toolkit.get_YFin_data_online.__name__ = "get_YFin_data_online"
toolkit.get_stockstats_indicators_report = Mock(
side_effect=mock_get_stockstats_indicators_report
)
toolkit.get_stockstats_indicators_report.name = "get_stockstats_indicators_report"
toolkit.get_stockstats_indicators_report.__name__ = (
"get_stockstats_indicators_report"
)
toolkit.get_stockstats_indicators_report_online = Mock(
side_effect=mock_get_stockstats_indicators_report_online
)
toolkit.get_stockstats_indicators_report_online.name = (
"get_stockstats_indicators_report_online"
)
toolkit.get_stockstats_indicators_report_online.__name__ = (
"get_stockstats_indicators_report_online"
)
toolkit.get_reddit_stock_info = Mock(side_effect=mock_get_reddit_stock_info)
toolkit.get_reddit_stock_info.name = "get_reddit_stock_info"
toolkit.get_reddit_stock_info.__name__ = "get_reddit_stock_info"
toolkit.get_stock_news_openai = Mock(side_effect=mock_get_stock_news_openai)
toolkit.get_stock_news_openai.name = "get_stock_news_openai"
toolkit.get_stock_news_openai.__name__ = "get_stock_news_openai"
toolkit.get_finnhub_news = Mock(side_effect=mock_get_finnhub_news)
toolkit.get_finnhub_news.name = "get_finnhub_news"
toolkit.get_finnhub_news.__name__ = "get_finnhub_news"
toolkit.get_reddit_news = Mock(side_effect=mock_get_reddit_news)
toolkit.get_reddit_news.name = "get_reddit_news"
toolkit.get_reddit_news.__name__ = "get_reddit_news"
toolkit.get_global_news_openai = Mock(side_effect=mock_get_global_news_openai)
toolkit.get_global_news_openai.name = "get_global_news_openai"
toolkit.get_global_news_openai.__name__ = "get_global_news_openai"
toolkit.get_google_news = Mock(side_effect=mock_get_google_news)
toolkit.get_google_news.name = "get_google_news"
toolkit.get_google_news.__name__ = "get_google_news"
toolkit.get_fundamentals_openai = Mock(side_effect=mock_get_fundamentals_openai)
toolkit.get_fundamentals_openai.name = "get_fundamentals_openai"
toolkit.get_fundamentals_openai.__name__ = "get_fundamentals_openai"
toolkit.get_finnhub_company_insider_sentiment = Mock(
side_effect=mock_get_finnhub_company_insider_sentiment
)
toolkit.get_finnhub_company_insider_sentiment.name = (
"get_finnhub_company_insider_sentiment"
)
toolkit.get_finnhub_company_insider_sentiment.__name__ = (
"get_finnhub_company_insider_sentiment"
)
toolkit.get_finnhub_company_insider_transactions = Mock(
side_effect=mock_get_finnhub_company_insider_transactions
)
toolkit.get_finnhub_company_insider_transactions.name = (
"get_finnhub_company_insider_transactions"
)
toolkit.get_finnhub_company_insider_transactions.__name__ = (
"get_finnhub_company_insider_transactions"
)
toolkit.get_simfin_balance_sheet = Mock(side_effect=mock_get_simfin_balance_sheet)
toolkit.get_simfin_balance_sheet.name = "get_simfin_balance_sheet"
toolkit.get_simfin_balance_sheet.__name__ = "get_simfin_balance_sheet"
toolkit.get_simfin_cashflow = Mock(side_effect=mock_get_simfin_cashflow)
toolkit.get_simfin_cashflow.name = "get_simfin_cashflow"
toolkit.get_simfin_cashflow.__name__ = "get_simfin_cashflow"
toolkit.get_simfin_income_stmt = Mock(side_effect=mock_get_simfin_income_stmt)
toolkit.get_simfin_income_stmt.name = "get_simfin_income_stmt"
toolkit.get_simfin_income_stmt.__name__ = "get_simfin_income_stmt"
return toolkit
@pytest.fixture
def sample_market_data():
"""Sample market data for testing."""
return {
"close": 150.0,
"high": 155.0,
"low": 148.0,
"volume": 1000000,
"date": "2024-05-10",
}
@pytest.fixture
def sample_financial_data():
"""Sample financial data for testing."""
return {
"revenue": 100000000,
"net_income": 20000000,
"assets": 500000000,
"liabilities": 200000000,
"period": "annual",
"date": "2024-05-10",
}
@pytest.fixture(autouse=True)
def setup_test_environment(temp_data_dir):
"""Set up test directories."""
# Create test data directories
data_cache_dir = os.path.join(temp_data_dir, "dataflows", "data_cache")
os.makedirs(data_cache_dir, exist_ok=True)
finnhub_dir = os.path.join(data_cache_dir, "finnhub_data")
os.makedirs(finnhub_dir, exist_ok=True)
# Create subdirectories for different data types
for data_type in ["news_data", "insider_trans", "SEC_filings", "fin_as_reported"]:
os.makedirs(os.path.join(finnhub_dir, data_type), exist_ok=True)
@pytest.fixture
def mock_memory():
"""Mock financial situation memory."""
memory = Mock()
memory.add_memory = Mock()
memory.get_memory = Mock(return_value="")
memory.clear_memory = Mock()
return memory
# Pytest configuration
def pytest_configure(config):
"""Configure pytest with custom markers."""
config.addinivalue_line(
"markers",
"integration: mark test as integration test (slow)",
)
config.addinivalue_line("markers", "unit: mark test as unit test (fast)")
config.addinivalue_line("markers", "api: mark test as requiring API access")
def pytest_collection_modifyitems(config, items):
"""Modify test collection to add markers based on location."""
for item in items:
# Add unit marker to tests in unit/ directory
if "unit" in str(item.fspath):
item.add_marker(pytest.mark.unit)
# Add integration marker to tests in integration/ directory
elif "integration" in str(item.fspath):
item.add_marker(pytest.mark.integration)