309 lines
12 KiB
Python
309 lines
12 KiB
Python
import pytest
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
from datetime import datetime, timedelta
|
|
from tradingagents.dataflows.interface import (
|
|
parse_lookback_period,
|
|
get_bulk_news,
|
|
get_category_for_method,
|
|
get_vendor,
|
|
route_to_vendor,
|
|
TOOLS_CATEGORIES,
|
|
VENDOR_METHODS,
|
|
)
|
|
from tradingagents.agents.discovery import NewsArticle
|
|
|
|
|
|
class TestParseLookbackPeriod:
|
|
"""Test suite for parse_lookback_period function."""
|
|
|
|
def test_parse_lookback_1h(self):
|
|
"""Test parsing '1h' lookback period."""
|
|
assert parse_lookback_period("1h") == 1
|
|
|
|
def test_parse_lookback_6h(self):
|
|
"""Test parsing '6h' lookback period."""
|
|
assert parse_lookback_period("6h") == 6
|
|
|
|
def test_parse_lookback_24h(self):
|
|
"""Test parsing '24h' lookback period."""
|
|
assert parse_lookback_period("24h") == 24
|
|
|
|
def test_parse_lookback_7d(self):
|
|
"""Test parsing '7d' lookback period."""
|
|
assert parse_lookback_period("7d") == 168 # 7 * 24
|
|
|
|
def test_parse_lookback_case_insensitive(self):
|
|
"""Test that parsing is case insensitive."""
|
|
assert parse_lookback_period("1H") == 1
|
|
assert parse_lookback_period("6H") == 6
|
|
assert parse_lookback_period("24H") == 24
|
|
assert parse_lookback_period("7D") == 168
|
|
|
|
def test_parse_lookback_with_spaces(self):
|
|
"""Test parsing with leading/trailing spaces."""
|
|
assert parse_lookback_period(" 1h ") == 1
|
|
assert parse_lookback_period(" 24h ") == 24
|
|
|
|
def test_parse_lookback_invalid_value(self):
|
|
"""Test that invalid values raise ValueError."""
|
|
with pytest.raises(ValueError, match="Invalid lookback period"):
|
|
parse_lookback_period("invalid")
|
|
|
|
with pytest.raises(ValueError):
|
|
parse_lookback_period("10h")
|
|
|
|
with pytest.raises(ValueError):
|
|
parse_lookback_period("2d")
|
|
|
|
|
|
class TestGetCategoryForMethod:
|
|
"""Test suite for get_category_for_method function."""
|
|
|
|
def test_get_category_core_stock_apis(self):
|
|
"""Test categorization of core stock API methods."""
|
|
assert get_category_for_method("get_stock_data") == "core_stock_apis"
|
|
|
|
def test_get_category_technical_indicators(self):
|
|
"""Test categorization of technical indicator methods."""
|
|
assert get_category_for_method("get_indicators") == "technical_indicators"
|
|
|
|
def test_get_category_fundamental_data(self):
|
|
"""Test categorization of fundamental data methods."""
|
|
assert get_category_for_method("get_fundamentals") == "fundamental_data"
|
|
assert get_category_for_method("get_balance_sheet") == "fundamental_data"
|
|
assert get_category_for_method("get_cashflow") == "fundamental_data"
|
|
assert get_category_for_method("get_income_statement") == "fundamental_data"
|
|
|
|
def test_get_category_news_data(self):
|
|
"""Test categorization of news data methods."""
|
|
assert get_category_for_method("get_news") == "news_data"
|
|
assert get_category_for_method("get_global_news") == "news_data"
|
|
assert get_category_for_method("get_insider_sentiment") == "news_data"
|
|
assert get_category_for_method("get_insider_transactions") == "news_data"
|
|
assert get_category_for_method("get_bulk_news") == "news_data"
|
|
|
|
def test_get_category_invalid_method(self):
|
|
"""Test that invalid methods raise ValueError."""
|
|
with pytest.raises(ValueError, match="not found in any category"):
|
|
get_category_for_method("nonexistent_method")
|
|
|
|
|
|
class TestGetBulkNews:
|
|
"""Test suite for get_bulk_news function."""
|
|
|
|
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
|
|
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
|
def test_get_bulk_news_default_period(self, mock_convert, mock_fetch):
|
|
"""Test get_bulk_news with default lookback period."""
|
|
mock_fetch.return_value = []
|
|
mock_convert.return_value = []
|
|
|
|
result = get_bulk_news()
|
|
|
|
mock_fetch.assert_called_once_with("24h")
|
|
assert isinstance(result, list)
|
|
|
|
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
|
|
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
|
def test_get_bulk_news_custom_period(self, mock_convert, mock_fetch):
|
|
"""Test get_bulk_news with custom lookback period."""
|
|
mock_fetch.return_value = []
|
|
mock_convert.return_value = []
|
|
|
|
result = get_bulk_news("6h")
|
|
|
|
mock_fetch.assert_called_once_with("6h")
|
|
|
|
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
|
|
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
|
def test_get_bulk_news_caching(self, mock_convert, mock_fetch):
|
|
"""Test that results are cached."""
|
|
mock_raw_articles = [
|
|
{
|
|
"title": "Test Article",
|
|
"source": "Source",
|
|
"url": "https://example.com",
|
|
"published_at": datetime.now().isoformat(),
|
|
"content_snippet": "Content",
|
|
}
|
|
]
|
|
|
|
mock_article = NewsArticle(
|
|
title="Test Article",
|
|
source="Source",
|
|
url="https://example.com",
|
|
published_at=datetime.now(),
|
|
content_snippet="Content",
|
|
ticker_mentions=[],
|
|
)
|
|
|
|
mock_fetch.return_value = mock_raw_articles
|
|
mock_convert.return_value = [mock_article]
|
|
|
|
# First call should fetch
|
|
result1 = get_bulk_news("24h")
|
|
call_count_1 = mock_fetch.call_count
|
|
|
|
# Second call within cache TTL should use cache
|
|
result2 = get_bulk_news("24h")
|
|
call_count_2 = mock_fetch.call_count
|
|
|
|
# Fetch should not be called again if cache is working
|
|
# (Note: actual caching behavior depends on implementation)
|
|
assert isinstance(result1, list)
|
|
assert isinstance(result2, list)
|
|
|
|
@patch('tradingagents.dataflows.interface._fetch_bulk_news_from_vendor')
|
|
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
|
def test_get_bulk_news_converts_articles(self, mock_convert, mock_fetch):
|
|
"""Test that raw articles are converted to NewsArticle objects."""
|
|
mock_raw = [{"title": "Test"}]
|
|
mock_articles = [Mock(spec=NewsArticle)]
|
|
|
|
mock_fetch.return_value = mock_raw
|
|
mock_convert.return_value = mock_articles
|
|
|
|
result = get_bulk_news("24h")
|
|
|
|
mock_convert.assert_called_once_with(mock_raw)
|
|
assert result == mock_articles
|
|
|
|
|
|
class TestRouteToVendor:
|
|
"""Test suite for route_to_vendor function."""
|
|
|
|
@patch('tradingagents.dataflows.interface.get_vendor')
|
|
@patch('tradingagents.dataflows.interface.get_category_for_method')
|
|
def test_route_to_vendor_basic(self, mock_get_category, mock_get_vendor):
|
|
"""Test basic vendor routing."""
|
|
mock_get_category.return_value = "core_stock_apis"
|
|
mock_get_vendor.return_value = "yfinance"
|
|
|
|
# Mock the vendor function
|
|
with patch.dict(VENDOR_METHODS, {"get_stock_data": {"yfinance": Mock(return_value="test_data")}}):
|
|
result = route_to_vendor("get_stock_data", "AAPL", "2024-01-01")
|
|
|
|
assert result == "test_data"
|
|
|
|
@patch('tradingagents.dataflows.interface.get_vendor')
|
|
@patch('tradingagents.dataflows.interface.get_category_for_method')
|
|
def test_route_to_vendor_fallback(self, mock_get_category, mock_get_vendor):
|
|
"""Test vendor fallback when primary fails."""
|
|
mock_get_category.return_value = "news_data"
|
|
mock_get_vendor.return_value = "alpha_vantage"
|
|
|
|
# Mock primary vendor to fail, secondary to succeed
|
|
primary_mock = Mock(side_effect=Exception("Primary failed"))
|
|
secondary_mock = Mock(return_value="fallback_data")
|
|
|
|
with patch.dict(VENDOR_METHODS, {
|
|
"get_news": {
|
|
"alpha_vantage": primary_mock,
|
|
"openai": secondary_mock,
|
|
}
|
|
}):
|
|
result = route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
|
|
|
|
assert result == "fallback_data"
|
|
assert primary_mock.called
|
|
assert secondary_mock.called
|
|
|
|
@patch('tradingagents.dataflows.interface.get_vendor')
|
|
@patch('tradingagents.dataflows.interface.get_category_for_method')
|
|
def test_route_to_vendor_all_fail(self, mock_get_category, mock_get_vendor):
|
|
"""Test that RuntimeError is raised when all vendors fail."""
|
|
mock_get_category.return_value = "news_data"
|
|
mock_get_vendor.return_value = "alpha_vantage"
|
|
|
|
# All vendors fail
|
|
failing_mock = Mock(side_effect=Exception("Failed"))
|
|
|
|
with patch.dict(VENDOR_METHODS, {
|
|
"get_news": {
|
|
"alpha_vantage": failing_mock,
|
|
"openai": failing_mock,
|
|
}
|
|
}):
|
|
with pytest.raises(RuntimeError, match="All vendor implementations failed"):
|
|
route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
|
|
|
|
@patch('tradingagents.dataflows.interface.get_vendor')
|
|
@patch('tradingagents.dataflows.interface.get_category_for_method')
|
|
def test_route_to_vendor_multiple_results(self, mock_get_category, mock_get_vendor):
|
|
"""Test handling of multiple vendor implementations."""
|
|
mock_get_category.return_value = "news_data"
|
|
mock_get_vendor.return_value = "local"
|
|
|
|
# Local vendor has multiple implementations
|
|
impl1 = Mock(return_value="result1")
|
|
impl2 = Mock(return_value="result2")
|
|
|
|
with patch.dict(VENDOR_METHODS, {
|
|
"get_news": {
|
|
"local": [impl1, impl2],
|
|
}
|
|
}):
|
|
result = route_to_vendor("get_news", "AAPL", "2024-01-01", "2024-01-31")
|
|
|
|
# Should combine multiple results
|
|
assert isinstance(result, str)
|
|
assert impl1.called
|
|
assert impl2.called
|
|
|
|
def test_route_to_vendor_unsupported_method(self):
|
|
"""Test that ValueError is raised for unsupported methods."""
|
|
with pytest.raises(ValueError, match="not found in any category"):
|
|
route_to_vendor("nonexistent_method", "arg1")
|
|
|
|
|
|
class TestConvertToNewsArticles:
|
|
"""Test suite for _convert_to_news_articles function."""
|
|
|
|
@patch('tradingagents.dataflows.interface._convert_to_news_articles')
|
|
def test_convert_empty_list(self, mock_convert):
|
|
"""Test converting empty article list."""
|
|
mock_convert.return_value = []
|
|
|
|
from tradingagents.dataflows.interface import _convert_to_news_articles
|
|
result = _convert_to_news_articles([])
|
|
|
|
assert result == []
|
|
|
|
@patch('tradingagents.dataflows.interface.NewsArticle')
|
|
def test_convert_valid_articles(self, mock_news_article):
|
|
"""Test converting valid raw articles."""
|
|
from tradingagents.dataflows.interface import _convert_to_news_articles
|
|
|
|
raw_articles = [
|
|
{
|
|
"title": "Article 1",
|
|
"source": "Source 1",
|
|
"url": "https://example.com/1",
|
|
"published_at": datetime(2024, 1, 15).isoformat(),
|
|
"content_snippet": "Content 1",
|
|
}
|
|
]
|
|
|
|
result = _convert_to_news_articles(raw_articles)
|
|
|
|
# Should attempt to create NewsArticle
|
|
assert isinstance(result, list)
|
|
|
|
def test_convert_invalid_date_format(self):
|
|
"""Test handling of invalid date formats."""
|
|
from tradingagents.dataflows.interface import _convert_to_news_articles
|
|
|
|
raw_articles = [
|
|
{
|
|
"title": "Article",
|
|
"source": "Source",
|
|
"url": "https://example.com",
|
|
"published_at": "invalid_date",
|
|
"content_snippet": "Content",
|
|
}
|
|
]
|
|
|
|
result = _convert_to_news_articles(raw_articles)
|
|
|
|
# Should handle gracefully
|
|
assert isinstance(result, list) |