199 lines
6.9 KiB
Python
199 lines
6.9 KiB
Python
"""Unit tests for data interface routing."""
|
|
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from tradingagents.dataflows.interface import (
|
|
TOOLS_CATEGORIES,
|
|
VENDOR_LIST,
|
|
VENDOR_METHODS,
|
|
get_category_for_method,
|
|
get_vendor,
|
|
route_to_vendor,
|
|
)
|
|
|
|
|
|
class TestToolsCategories:
    """Check the shape of the TOOLS_CATEGORIES registry.

    Each test first confirms that a category key is present, then confirms
    that the expected tool names appear in that category's ``"tools"`` entry.
    """

    # Apply the ``unit`` marker to every test collected from this class.
    pytestmark = pytest.mark.unit

    @staticmethod
    def _missing_tools(category, expected):
        """Return the names from *expected* not registered under *category*."""
        registered = TOOLS_CATEGORIES[category]["tools"]
        return [name for name in expected if name not in registered]

    def test_core_stock_apis_category_exists(self):
        """core_stock_apis is registered and lists get_stock_data."""
        assert "core_stock_apis" in TOOLS_CATEGORIES
        assert self._missing_tools("core_stock_apis", ["get_stock_data"]) == []

    def test_technical_indicators_category_exists(self):
        """technical_indicators is registered and lists get_indicators."""
        assert "technical_indicators" in TOOLS_CATEGORIES
        assert self._missing_tools("technical_indicators", ["get_indicators"]) == []

    def test_fundamental_data_category_exists(self):
        """fundamental_data is registered and lists every statement tool."""
        assert "fundamental_data" in TOOLS_CATEGORIES
        wanted = (
            "get_fundamentals",
            "get_balance_sheet",
            "get_cashflow",
            "get_income_statement",
        )
        assert self._missing_tools("fundamental_data", wanted) == []

    def test_news_data_category_exists(self):
        """news_data is registered and lists the news and insider tools."""
        assert "news_data" in TOOLS_CATEGORIES
        wanted = ("get_news", "get_global_news", "get_insider_transactions")
        assert self._missing_tools("news_data", wanted) == []
|
|
|
|
|
|
class TestVendorList:
    """Check which data vendors VENDOR_LIST advertises."""

    # Apply the ``unit`` marker to every test collected from this class.
    pytestmark = pytest.mark.unit

    def test_yfinance_in_vendor_list(self):
        """yfinance must appear among the listed vendors."""
        vendor = "yfinance"
        assert vendor in VENDOR_LIST

    def test_alpha_vantage_in_vendor_list(self):
        """alpha_vantage must appear among the listed vendors."""
        vendor = "alpha_vantage"
        assert vendor in VENDOR_LIST

    def test_vendor_list_length(self):
        """VENDOR_LIST holds exactly two entries."""
        expected_count = 2
        assert len(VENDOR_LIST) == expected_count
|
|
|
|
|
|
class TestGetCategoryForMethod:
    """Check how get_category_for_method maps tool names onto categories."""

    # Apply the ``unit`` marker to every test collected from this class.
    pytestmark = pytest.mark.unit

    def test_get_category_for_stock_data(self):
        """get_stock_data resolves to core_stock_apis."""
        assert get_category_for_method("get_stock_data") == "core_stock_apis"

    def test_get_category_for_indicators(self):
        """get_indicators resolves to technical_indicators."""
        assert get_category_for_method("get_indicators") == "technical_indicators"

    def test_get_category_for_fundamentals(self):
        """get_fundamentals resolves to fundamental_data."""
        assert get_category_for_method("get_fundamentals") == "fundamental_data"

    def test_get_category_for_news(self):
        """get_news resolves to news_data."""
        assert get_category_for_method("get_news") == "news_data"

    def test_get_category_for_invalid_method_raises(self):
        """An unknown method name raises ValueError whose message matches 'not found'."""
        expect_rejection = pytest.raises(ValueError, match="not found")
        with expect_rejection:
            get_category_for_method("invalid_method")
|
|
|
|
|
|
class TestGetVendor:
    """Check vendor resolution in get_vendor against a stubbed configuration."""

    # Apply the ``unit`` marker to every test collected from this class.
    pytestmark = pytest.mark.unit

    # Patch target: the get_config name as looked up by the interface module.
    _CONFIG_TARGET = "tradingagents.dataflows.interface.get_config"

    @classmethod
    def _stub_config(cls, data_vendors, tool_vendors):
        """Return a patcher under which get_config() yields the given vendor maps."""
        config = {"data_vendors": data_vendors, "tool_vendors": tool_vendors}
        return patch(cls._CONFIG_TARGET, return_value=config)

    def test_get_vendor_default(self):
        """Without a tool override, the category's configured vendor is returned."""
        with self._stub_config({"core_stock_apis": "yfinance"}, {}):
            resolved = get_vendor("core_stock_apis")
        assert resolved == "yfinance"

    def test_get_vendor_tool_level_override(self):
        """A tool_vendors entry takes precedence over the category vendor."""
        with self._stub_config(
            {"core_stock_apis": "yfinance"},
            {"get_stock_data": "alpha_vantage"},
        ):
            resolved = get_vendor("core_stock_apis", "get_stock_data")
        assert resolved == "alpha_vantage"

    def test_get_vendor_missing_category_uses_default(self):
        """An unconfigured category resolves to the literal string 'default'."""
        with self._stub_config({}, {}):
            resolved = get_vendor("unknown_category")
        assert resolved == "default"
|
|
|
|
|
|
class TestVendorMethods:
    """Check the VENDOR_METHODS dispatch table."""

    # Apply the ``unit`` marker to every test collected from this class.
    pytestmark = pytest.mark.unit

    def test_get_stock_data_has_both_vendors(self):
        """get_stock_data is served by both yfinance and alpha_vantage."""
        stock_vendors = VENDOR_METHODS["get_stock_data"]
        for vendor in ("yfinance", "alpha_vantage"):
            assert vendor in stock_vendors

    def test_all_methods_have_vendors(self):
        """No method in the table is left without at least one vendor."""
        for name, impls in VENDOR_METHODS.items():
            assert len(impls) >= 1, f"Method {name} has no vendors"
|
|
|
|
|
|
class TestRouteToVendor:
    """Tests for route_to_vendor function."""

    @pytest.mark.unit
    @patch("tradingagents.dataflows.interface.get_config")
    def test_route_to_vendor_invalid_method_raises(self, mock_get_config):
        """Test that routing invalid method raises ValueError."""
        mock_get_config.return_value = {"data_vendors": {}, "tool_vendors": {}}

        with pytest.raises(ValueError, match="not found"):
            route_to_vendor("invalid_method", "AAPL")

    @pytest.mark.unit
    @pytest.mark.skip(
        reason=(
            "Not implemented: needs a stubbed vendor that raises the "
            "rate-limit error so fallback to the next vendor can be asserted"
        )
    )
    def test_route_to_vendor_fallback_on_rate_limit(self):
        """Test that vendor fallback works on rate limit errors.

        Placeholder. The previous body configured mocks but asserted
        nothing, so it always reported PASS. It is skipped with a reason so
        the missing coverage stays visible in test reports.
        """

    @pytest.mark.unit
    @patch("tradingagents.dataflows.interface.get_config")
    def test_route_to_vendor_no_available_vendor_raises(self, mock_get_config):
        """Test that no available vendor raises RuntimeError."""
        mock_get_config.return_value = {
            "data_vendors": {"core_stock_apis": "nonexistent_vendor"},
            "tool_vendors": {},
        }

        # Empty the vendor table for this method (restored on exit). The
        # configured vendor is then unknown and no fallback implementation
        # exists, so no real vendor/API can be reached by this test.
        with patch.dict(VENDOR_METHODS, {"get_stock_data": {}}):
            with pytest.raises(RuntimeError):
                route_to_vendor("get_stock_data", "AAPL")
|