TradingAgents/tests/dataflows/test_interface.py

309 lines
12 KiB
Python

import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta
from tradingagents.dataflows.interface import (
parse_lookback_period,
get_bulk_news,
get_category_for_method,
get_vendor,
route_to_vendor,
TOOLS_CATEGORIES,
VENDOR_METHODS,
)
from tradingagents.agents.discovery import NewsArticle
class TestParseLookbackPeriod:
"""Test suite for parse_lookback_period function."""
def test_parse_lookback_1h(self):
"""Test parsing '1h' lookback period."""
assert parse_lookback_period("1h") == 1
def test_parse_lookback_6h(self):
"""Test parsing '6h' lookback period."""
assert parse_lookback_period("6h") == 6
def test_parse_lookback_24h(self):
"""Test parsing '24h' lookback period."""
assert parse_lookback_period("24h") == 24
def test_parse_lookback_7d(self):
"""Test parsing '7d' lookback period."""
assert parse_lookback_period("7d") == 168 # 7 * 24
def test_parse_lookback_case_insensitive(self):
"""Test that parsing is case insensitive."""
assert parse_lookback_period("1H") == 1
assert parse_lookback_period("6H") == 6
assert parse_lookback_period("24H") == 24
assert parse_lookback_period("7D") == 168
def test_parse_lookback_with_spaces(self):
"""Test parsing with leading/trailing spaces."""
assert parse_lookback_period(" 1h ") == 1
assert parse_lookback_period(" 24h ") == 24
def test_parse_lookback_invalid_value(self):
"""Test that invalid values raise ValueError."""
with pytest.raises(ValueError, match="Invalid lookback period"):
parse_lookback_period("invalid")
with pytest.raises(ValueError):
parse_lookback_period("10h")
with pytest.raises(ValueError):
parse_lookback_period("2d")
class TestGetCategoryForMethod:
"""Test suite for get_category_for_method function."""
def test_get_category_core_stock_apis(self):
"""Test categorization of core stock API methods."""
assert get_category_for_method("get_stock_data") == "core_stock_apis"
def test_get_category_technical_indicators(self):
"""Test categorization of technical indicator methods."""
assert get_category_for_method("get_indicators") == "technical_indicators"
def test_get_category_fundamental_data(self):
"""Test categorization of fundamental data methods."""
assert get_category_for_method("get_fundamentals") == "fundamental_data"
assert get_category_for_method("get_balance_sheet") == "fundamental_data"
assert get_category_for_method("get_cashflow") == "fundamental_data"
assert get_category_for_method("get_income_statement") == "fundamental_data"
def test_get_category_news_data(self):
"""Test categorization of news data methods."""
assert get_category_for_method("get_news") == "news_data"
assert get_category_for_method("get_global_news") == "news_data"
assert get_category_for_method("get_insider_sentiment") == "news_data"
assert get_category_for_method("get_insider_transactions") == "news_data"
assert get_category_for_method("get_bulk_news") == "news_data"
def test_get_category_invalid_method(self):
"""Test that invalid methods raise ValueError."""
with pytest.raises(ValueError, match="not found in any category"):
get_category_for_method("nonexistent_method")
class TestGetBulkNews:
"""Test suite for get_bulk_news function."""
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
def test_get_bulk_news_default_period(self, mock_convert, mock_fetch):
"""Test get_bulk_news with default lookback period."""
mock_fetch.return_value = []
mock_convert.return_value = []
result = get_bulk_news()
mock_fetch.assert_called_once_with("24h")
assert isinstance(result, list)
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
def test_get_bulk_news_custom_period(self, mock_convert, mock_fetch):
"""Test get_bulk_news with custom lookback period."""
mock_fetch.return_value = []
mock_convert.return_value = []
result = get_bulk_news("6h")
mock_fetch.assert_called_once_with("6h")
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
def test_get_bulk_news_caching(self, mock_convert, mock_fetch):
"""Test that results are cached."""
mock_raw_articles = [
{
"title": "Test Article",
"source": "Source",
"url": "https://example.com",
"published_at": datetime.now().isoformat(),
"content_snippet": "Content",
}
]
mock_article = NewsArticle(
title="Test Article",
source="Source",
url="https://example.com",
published_at=datetime.now(),
content_snippet="Content",
ticker_mentions=[],
)
mock_fetch.return_value = mock_raw_articles
mock_convert.return_value = [mock_article]
# First call should fetch
result1 = get_bulk_news("24h")
call_count_1 = mock_fetch.call_count
# Second call within cache TTL should use cache
result2 = get_bulk_news("24h")
call_count_2 = mock_fetch.call_count
# Fetch should not be called again if cache is working
# (Note: actual caching behavior depends on implementation)
assert isinstance(result1, list)
assert isinstance(result2, list)
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
def test_get_bulk_news_converts_articles(self, mock_convert, mock_fetch):
"""Test that raw articles are converted to NewsArticle objects."""
mock_raw = [{"title": "Test"}]
mock_articles = [Mock(spec=NewsArticle)]
mock_fetch.return_value = mock_raw
mock_convert.return_value = mock_articles
result = get_bulk_news("24h")
mock_convert.assert_called_once_with(mock_raw)
assert result == mock_articles
class TestRouteToVendor:
"""Test suite for route_to_vendor function."""
@patch('tradingagents.dataflows.interface.get_vendor')
@patch('tradingagents.dataflows.interface.get_category_for_method')
def test_route_to_vendor_basic(self, mock_get_category, mock_get_vendor):
"""Test basic vendor routing."""
mock_get_category.return_value = "core_stock_apis"
mock_get_vendor.return_value = "yfinance"
# Mock the vendor function
with patch.dict(VENDOR_METHODS, {"get_stock_data": {"yfinance": Mock(return_value="test_data")}}):
result = route_to_vendor("get_stock_data", "AAPL", "2024-01-01")
assert result == "test_data"
@patch('tradingagents.dataflows.interface.get_vendor')
@patch('tradingagents.dataflows.interface.get_category_for_method')
def test_route_to_vendor_fallback(self, mock_get_category, mock_get_vendor):
"""Test vendor fallback when primary fails."""
mock_get_category.return_value = "news_data"
mock_get_vendor.return_value = "alpha_vantage"
# Mock primary vendor to fail, secondary to succeed
primary_mock = Mock(side_effect=Exception("Primary failed"))
secondary_mock = Mock(return_value="fallback_data")
with patch.dict(VENDOR_METHODS, {
"get_news": {
"alpha_vantage": primary_mock,
"openai": secondary_mock,
}
}):
result = route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
assert result == "fallback_data"
assert primary_mock.called
assert secondary_mock.called
@patch('tradingagents.dataflows.interface.get_vendor')
@patch('tradingagents.dataflows.interface.get_category_for_method')
def test_route_to_vendor_all_fail(self, mock_get_category, mock_get_vendor):
"""Test that RuntimeError is raised when all vendors fail."""
mock_get_category.return_value = "news_data"
mock_get_vendor.return_value = "alpha_vantage"
# All vendors fail
failing_mock = Mock(side_effect=Exception("Failed"))
with patch.dict(VENDOR_METHODS, {
"get_news": {
"alpha_vantage": failing_mock,
"openai": failing_mock,
}
}):
with pytest.raises(RuntimeError, match="All vendor implementations failed"):
route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
@patch('tradingagents.dataflows.interface.get_vendor')
@patch('tradingagents.dataflows.interface.get_category_for_method')
def test_route_to_vendor_multiple_results(self, mock_get_category, mock_get_vendor):
"""Test handling of multiple vendor implementations."""
mock_get_category.return_value = "news_data"
mock_get_vendor.return_value = "local"
# Local vendor has multiple implementations
impl1 = Mock(return_value="result1")
impl2 = Mock(return_value="result2")
with patch.dict(VENDOR_METHODS, {
"get_news": {
"local": [impl1, impl2],
}
}):
result = route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
# Should combine multiple results
assert isinstance(result, str)
assert impl1.called
assert impl2.called
def test_route_to_vendor_unsupported_method(self):
"""Test that ValueError is raised for unsupported methods."""
with pytest.raises(ValueError, match="not found in any category"):
route_to_vendor("nonexistent_method", "arg1")
class TestConvertToNewsArticles:
"""Test suite for _convert_to_news_articles function."""
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
def test_convert_empty_list(self, mock_convert):
"""Test converting empty article list."""
mock_convert.return_value = []
from tradingagents.dataflows.interface import _convert_to_news_articles
result = _convert_to_news_articles([])
assert result == []
@patch('tradingagents.dataflows.interface.NewsArticle')
def test_convert_valid_articles(self, mock_news_article):
"""Test converting valid raw articles."""
from tradingagents.dataflows.interface import _convert_to_news_articles
raw_articles = [
{
"title": "Article 1",
"source": "Source 1",
"url": "https://example.com/1",
"published_at": datetime(2024, 1, 15).isoformat(),
"content_snippet": "Content 1",
}
]
result = _convert_to_news_articles(raw_articles)
# Should attempt to create NewsArticle
assert isinstance(result, list)
def test_convert_invalid_date_format(self):
"""Test handling of invalid date formats."""
from tradingagents.dataflows.interface import _convert_to_news_articles
raw_articles = [
{
"title": "Article",
"source": "Source",
"url": "https://example.com",
"published_at": "invalid_date",
"content_snippet": "Content",
}
]
result = _convert_to_news_articles(raw_articles)
# Should handle gracefully
assert isinstance(result, list)