From 2c802647e4d59c0d00d4d83aa501b865ae008ce5 Mon Sep 17 00:00:00 2001 From: Andrew Kaszubski Date: Fri, 26 Dec 2025 16:47:41 +1100 Subject: [PATCH] feat(dataflows): add vendor registry pattern for extensible data vendor routing - Fixes #11 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements [DATA-10] Interface routing - add new data vendors with: - VendorRegistry: Thread-safe singleton for centralized vendor registration - VendorCapability enum: STOCK_DATA, FUNDAMENTALS, NEWS, MACROECONOMIC, etc. - BaseVendor ABC: 3-stage lifecycle (transform_query, extract_data, transform_data) - SimpleVendor: Wrapper for migrating existing vendor functions - Decorators: @register_vendor, @vendor_method, @rate_limited, @with_retry, @cache_result - RateLimiter: Thread-safe sliding window rate limiting - 84 tests covering registry, base vendor, and decorators Files: - tradingagents/dataflows/vendor_registry.py (329 lines) - tradingagents/dataflows/base_vendor.py (365 lines) - tradingagents/dataflows/vendor_decorators.py (351 lines) - tests/unit/dataflows/test_vendor_registry.py (30 tests) - tests/unit/dataflows/test_base_vendor.py (28 tests) - tests/unit/dataflows/test_vendor_decorators.py (26 tests) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/unit/dataflows/test_base_vendor.py | 424 ++++++++++++++++++ .../unit/dataflows/test_vendor_decorators.py | 416 +++++++++++++++++ tests/unit/dataflows/test_vendor_registry.py | 354 +++++++++++++++ tradingagents/dataflows/base_vendor.py | 365 +++++++++++++++ tradingagents/dataflows/vendor_decorators.py | 351 +++++++++++++++ tradingagents/dataflows/vendor_registry.py | 329 ++++++++++++++ 6 files changed, 2239 insertions(+) create mode 100644 tests/unit/dataflows/test_base_vendor.py create mode 100644 tests/unit/dataflows/test_vendor_decorators.py create mode 100644 tests/unit/dataflows/test_vendor_registry.py create mode 100644 
tradingagents/dataflows/base_vendor.py create mode 100644 tradingagents/dataflows/vendor_decorators.py create mode 100644 tradingagents/dataflows/vendor_registry.py diff --git a/tests/unit/dataflows/test_base_vendor.py b/tests/unit/dataflows/test_base_vendor.py new file mode 100644 index 00000000..0888b1f7 --- /dev/null +++ b/tests/unit/dataflows/test_base_vendor.py @@ -0,0 +1,424 @@ +"""Tests for BaseVendor abstract class. + +Issue #11: [DATA-10] Interface routing - add new data vendors +""" + +import pytest +import time +import threading +from typing import Dict, Any +from unittest.mock import Mock, patch + +from tradingagents.dataflows.base_vendor import ( + BaseVendor, + SimpleVendor, + VendorResponse, +) + +pytestmark = pytest.mark.unit + + +class ConcreteVendor(BaseVendor): + """Concrete implementation for testing.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._vendor_name = "test_vendor" + + @property + def name(self) -> str: + return self._vendor_name + + def _transform_query(self, method: str, **kwargs) -> Dict[str, Any]: + return {"method": method, **kwargs} + + def _extract_data(self, method: str, query: Dict[str, Any]) -> Any: + return {"raw": "data", "query": query} + + def _transform_data(self, method: str, raw_data: Any, query: Dict[str, Any]) -> Any: + return {"transformed": raw_data} + + +class TestVendorResponse: + """Tests for VendorResponse dataclass.""" + + def test_default_values(self): + """Test default values are set correctly.""" + response = VendorResponse() + assert response.data is None + assert response.success is True + assert response.vendor_name == "" + assert response.method_name == "" + assert response.execution_time_ms == 0.0 + assert response.error_message == "" + assert response.metadata == {} + + def test_custom_values(self): + """Test custom values are preserved.""" + response = VendorResponse( + data={"test": "data"}, + success=True, + vendor_name="yfinance", + method_name="get_stock", + 
execution_time_ms=150.5, + metadata={"source": "api"} + ) + assert response.data == {"test": "data"} + assert response.vendor_name == "yfinance" + assert response.method_name == "get_stock" + assert response.execution_time_ms == 150.5 + + def test_failure_response(self): + """Test failed response.""" + response = VendorResponse( + success=False, + error_message="API Error" + ) + assert response.success is False + assert response.error_message == "API Error" + + def test_is_empty_with_none(self): + """Test is_empty returns True for None data.""" + response = VendorResponse(data=None) + assert response.is_empty is True + + def test_is_empty_with_empty_list(self): + """Test is_empty returns True for empty list.""" + response = VendorResponse(data=[]) + assert response.is_empty is True + + def test_is_empty_with_empty_dict(self): + """Test is_empty returns True for empty dict.""" + response = VendorResponse(data={}) + assert response.is_empty is True + + def test_is_empty_with_empty_string(self): + """Test is_empty returns True for empty string.""" + response = VendorResponse(data="") + assert response.is_empty is True + + def test_is_empty_with_data(self): + """Test is_empty returns False when data exists.""" + response = VendorResponse(data={"test": "data"}) + assert response.is_empty is False + + +class TestBaseVendorAbstract: + """Tests for BaseVendor abstract method enforcement.""" + + def test_cannot_instantiate_base_vendor(self): + """Test that BaseVendor cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseVendor() + + def test_can_instantiate_concrete_vendor(self): + """Test that concrete implementation can be instantiated.""" + vendor = ConcreteVendor() + assert vendor is not None + assert vendor.name == "test_vendor" + + +class TestBaseVendorConfiguration: + """Tests for BaseVendor configuration.""" + + def test_default_configuration(self): + """Test default configuration values.""" + vendor = ConcreteVendor() + assert 
vendor._max_retries == 3 + assert vendor._retry_delay == 1.0 + assert vendor._retry_backoff == 2.0 + assert vendor._timeout is None + + def test_custom_configuration(self): + """Test custom configuration values.""" + vendor = ConcreteVendor( + max_retries=5, + retry_delay=0.5, + retry_backoff=3.0, + timeout=30.0 + ) + assert vendor._max_retries == 5 + assert vendor._retry_delay == 0.5 + assert vendor._retry_backoff == 3.0 + assert vendor._timeout == 30.0 + + +class TestBaseVendorExecution: + """Tests for BaseVendor execute method.""" + + def test_successful_execution(self): + """Test successful execution returns VendorResponse.""" + vendor = ConcreteVendor() + response = vendor.execute("get_data", ticker="AAPL") + + assert isinstance(response, VendorResponse) + assert response.success is True + assert response.vendor_name == "test_vendor" + assert response.method_name == "get_data" + assert response.execution_time_ms > 0 + + def test_execution_increments_call_count(self): + """Test that execute increments call count.""" + vendor = ConcreteVendor() + assert vendor.call_count == 0 + + vendor.execute("method1") + assert vendor.call_count == 1 + + vendor.execute("method2") + assert vendor.call_count == 2 + + def test_call_count_thread_safe(self): + """Test that call_count is thread-safe.""" + vendor = ConcreteVendor() + errors = [] + + def execute_many(): + try: + for _ in range(100): + vendor.execute("test") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=execute_many) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert vendor.call_count == 500 + + +class TestBaseVendorRetry: + """Tests for BaseVendor retry logic.""" + + def test_retry_on_connection_error(self): + """Test retry on connection error.""" + attempt_count = [0] + + class RetryVendor(ConcreteVendor): + def _extract_data(self, method, query): + attempt_count[0] += 1 + if attempt_count[0] < 3: + raise 
ConnectionError("Network error") + return {"data": "success"} + + vendor = RetryVendor(retry_delay=0.01) + response = vendor.execute("test") + + assert attempt_count[0] == 3 + assert response.success is True + + def test_no_retry_on_value_error(self): + """Test no retry on non-retryable error.""" + class NoRetryVendor(ConcreteVendor): + def _extract_data(self, method, query): + raise ValueError("Invalid input") + + vendor = NoRetryVendor(retry_delay=0.01) + response = vendor.execute("test") + + assert response.success is False + assert "Invalid input" in response.error_message + + def test_max_retries_exhausted(self): + """Test that max retries is respected.""" + attempt_count = [0] + + class AlwaysFailVendor(ConcreteVendor): + def _extract_data(self, method, query): + attempt_count[0] += 1 + raise ConnectionError("Network error") + + vendor = AlwaysFailVendor(max_retries=3, retry_delay=0.01) + response = vendor.execute("test") + + assert attempt_count[0] == 4 # 1 initial + 3 retries + assert response.success is False + + +class TestBaseVendorStatistics: + """Tests for BaseVendor statistics.""" + + def test_error_rate_calculation(self): + """Test error rate calculation.""" + class MixedVendor(ConcreteVendor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._fail_count = 0 + + def _extract_data(self, method, query): + self._fail_count += 1 + if self._fail_count <= 2: + raise ValueError("Error") + return {"data": "success"} + + vendor = MixedVendor(max_retries=0, retry_delay=0.01) + + # 2 failures + vendor.execute("test") + vendor.execute("test") + # 1 success + vendor.execute("test") + + # 2 errors out of 3 = 66.67% + assert 66 < vendor.error_rate < 67 + + def test_reset_stats(self): + """Test reset_stats clears counters.""" + vendor = ConcreteVendor() + vendor.execute("test") + vendor.execute("test") + + assert vendor.call_count == 2 + + vendor.reset_stats() + + assert vendor.call_count == 0 + assert vendor.error_rate == 0.0 + + +class 
TestSimpleVendor: + """Tests for SimpleVendor wrapper class.""" + + def test_simple_vendor_creation(self): + """Test creating a SimpleVendor.""" + mock_func = Mock(return_value={"data": "test"}) + vendor = SimpleVendor( + vendor_name="test", + methods={"get_data": mock_func} + ) + + assert vendor.name == "test" + + def test_simple_vendor_execution(self): + """Test executing a SimpleVendor method.""" + mock_func = Mock(return_value={"data": "test"}) + vendor = SimpleVendor( + vendor_name="test", + methods={"get_data": mock_func} + ) + + response = vendor.execute("get_data", ticker="AAPL") + + assert response.success is True + assert response.data == {"data": "test"} + mock_func.assert_called_once_with(ticker="AAPL") + + def test_simple_vendor_missing_method(self): + """Test SimpleVendor with missing method.""" + vendor = SimpleVendor( + vendor_name="test", + methods={} + ) + + response = vendor.execute("nonexistent") + + assert response.success is False + assert "not found" in response.error_message + + def test_simple_vendor_add_method(self): + """Test adding a method to SimpleVendor.""" + vendor = SimpleVendor( + vendor_name="test", + methods={} + ) + + mock_func = Mock(return_value="result") + vendor.add_method("new_method", mock_func) + + response = vendor.execute("new_method") + + assert response.success is True + assert response.data == "result" + + def test_simple_vendor_get_methods(self): + """Test getting list of methods.""" + vendor = SimpleVendor( + vendor_name="test", + methods={ + "method1": Mock(), + "method2": Mock(), + "method3": Mock() + } + ) + + methods = vendor.get_methods() + + assert len(methods) == 3 + assert "method1" in methods + assert "method2" in methods + assert "method3" in methods + + +class TestVendorLifecycle: + """Tests for 3-stage vendor lifecycle.""" + + def test_lifecycle_order(self): + """Test that lifecycle stages execute in order.""" + call_order = [] + + class OrderVendor(BaseVendor): + @property + def name(self): + return 
"order_test" + + def _transform_query(self, method, **kwargs): + call_order.append("transform_query") + return kwargs + + def _extract_data(self, method, query): + call_order.append("extract_data") + return {"raw": True} + + def _transform_data(self, method, raw_data, query): + call_order.append("transform_data") + return raw_data + + vendor = OrderVendor() + vendor.execute("test") + + assert call_order == ["transform_query", "extract_data", "transform_data"] + + def test_query_passed_to_extract(self): + """Test that transformed query is passed to extract.""" + class QueryVendor(BaseVendor): + @property + def name(self): + return "query_test" + + def _transform_query(self, method, **kwargs): + return {"ticker": kwargs.get("ticker", "").upper()} + + def _extract_data(self, method, query): + return {"symbol": query["ticker"]} + + def _transform_data(self, method, raw_data, query): + return raw_data + + vendor = QueryVendor() + response = vendor.execute("test", ticker="aapl") + + assert response.data["symbol"] == "AAPL" + + def test_raw_data_passed_to_transform(self): + """Test that raw data is passed to transform.""" + class TransformVendor(BaseVendor): + @property + def name(self): + return "transform_test" + + def _transform_query(self, method, **kwargs): + return kwargs + + def _extract_data(self, method, query): + return {"raw_value": 42} + + def _transform_data(self, method, raw_data, query): + return {"processed": raw_data["raw_value"] * 2} + + vendor = TransformVendor() + response = vendor.execute("test") + + assert response.data["processed"] == 84 diff --git a/tests/unit/dataflows/test_vendor_decorators.py b/tests/unit/dataflows/test_vendor_decorators.py new file mode 100644 index 00000000..083ac9a0 --- /dev/null +++ b/tests/unit/dataflows/test_vendor_decorators.py @@ -0,0 +1,416 @@ +"""Tests for vendor decorators. 
+ +Issue #11: [DATA-10] Interface routing - add new data vendors +""" + +import pytest +import time +import threading +from unittest.mock import Mock + +from tradingagents.dataflows.vendor_registry import ( + VendorCapability, + VendorMetadata, + VendorRegistry, +) +from tradingagents.dataflows.vendor_decorators import ( + register_vendor, + vendor_method, + rate_limited, + with_retry, + cache_result, + RateLimiter, + get_rate_limiter, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def reset_registry(): + """Reset the singleton registry before each test.""" + VendorRegistry.reset_instance() + yield + VendorRegistry.reset_instance() + + +class TestRegisterVendorDecorator: + """Tests for @register_vendor decorator.""" + + def test_register_vendor_basic(self): + """Test basic vendor registration with decorator.""" + @register_vendor( + name="test_vendor", + capabilities={VendorCapability.STOCK_DATA}, + priority=10 + ) + class TestVendor: + pass + + assert hasattr(TestVendor, '_vendor_metadata') + assert TestVendor._vendor_name == "test_vendor" + + def test_register_vendor_adds_to_registry(self): + """Test that decorator adds vendor to registry.""" + @register_vendor( + name="registered_vendor", + capabilities={VendorCapability.STOCK_DATA} + ) + class RegisteredVendor: + pass + + registry = VendorRegistry() + assert registry.get_vendor("registered_vendor") is not None + + def test_register_vendor_with_multiple_capabilities(self): + """Test registering vendor with multiple capabilities.""" + @register_vendor( + name="multi_vendor", + capabilities={ + VendorCapability.STOCK_DATA, + VendorCapability.FUNDAMENTALS, + VendorCapability.NEWS + } + ) + class MultiVendor: + pass + + registry = VendorRegistry() + vendor = registry.get_vendor("multi_vendor") + assert len(vendor.capabilities) == 3 + + def test_register_vendor_preserves_class_methods(self): + """Test that decorator preserves class methods.""" + @register_vendor( + name="method_vendor", + 
capabilities={VendorCapability.STOCK_DATA} + ) + class MethodVendor: + def fetch_data(self): + return "data" + + vendor = MethodVendor() + assert vendor.fetch_data() == "data" + + +class TestVendorMethodDecorator: + """Tests for @vendor_method decorator.""" + + def test_vendor_method_basic(self): + """Test basic method mapping.""" + class TestVendor: + @vendor_method("get_stock_data") + def fetch_stock(self, ticker): + return f"Stock: {ticker}" + + vendor = TestVendor() + result = vendor.fetch_stock("AAPL") + + assert result == "Stock: AAPL" + assert vendor.fetch_stock._vendor_method_name == "get_stock_data" + + def test_vendor_method_with_vendor_name(self): + """Test vendor_method with explicit vendor name.""" + # Register the vendor first + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="yfinance")) + + @vendor_method("get_stock", vendor_name="yfinance") + def get_yfinance_stock(ticker): + return f"YF: {ticker}" + + # Method should be registered + method = registry.get_method("yfinance", "get_stock") + assert method is not None + + def test_vendor_method_preserves_function(self): + """Test that decorator preserves function behavior.""" + class TestVendor: + @vendor_method("get_data") + def fetch_data(self, ticker, start_date=None): + return {"ticker": ticker, "start": start_date} + + vendor = TestVendor() + result = vendor.fetch_data("AAPL", start_date="2024-01-01") + + assert result["ticker"] == "AAPL" + assert result["start"] == "2024-01-01" + + +class TestRateLimiter: + """Tests for RateLimiter class.""" + + def test_rate_limiter_allows_calls_under_limit(self): + """Test that calls under limit are allowed.""" + limiter = RateLimiter(max_calls=5, period_seconds=60) + + for _ in range(5): + assert limiter.acquire() is True + + def test_rate_limiter_blocks_over_limit(self): + """Test that calls over limit are blocked.""" + limiter = RateLimiter(max_calls=3, period_seconds=60) + + # Use all slots + for _ in range(3): + limiter.acquire() 
+ + # 4th call should be blocked + assert limiter.acquire() is False + + def test_rate_limiter_wait_time(self): + """Test wait_time calculation.""" + limiter = RateLimiter(max_calls=2, period_seconds=60) + + # Fill up the slots + limiter.acquire() + limiter.acquire() + + # Should have positive wait time + wait = limiter.wait_time() + assert wait > 0 + assert wait <= 60 + + def test_rate_limiter_reset(self): + """Test reset clears the limiter.""" + limiter = RateLimiter(max_calls=2, period_seconds=60) + + limiter.acquire() + limiter.acquire() + assert limiter.acquire() is False + + limiter.reset() + assert limiter.acquire() is True + + +class TestRateLimitedDecorator: + """Tests for @rate_limited decorator.""" + + def test_rate_limited_allows_calls_under_limit(self): + """Test that calls under limit succeed.""" + call_count = [0] + + @rate_limited(max_calls=5, period_seconds=60) + def limited_func(): + call_count[0] += 1 + return "success" + + for _ in range(5): + result = limited_func() + assert result == "success" + + assert call_count[0] == 5 + + def test_rate_limited_blocks_over_limit(self): + """Test that calls over limit raise error.""" + @rate_limited(max_calls=2, period_seconds=60) + def limited_func(): + return "success" + + limited_func() + limited_func() + + with pytest.raises(RuntimeError, match="Rate limit exceeded"): + limited_func() + + def test_rate_limited_custom_exception(self): + """Test rate limiting with custom exception class.""" + class CustomRateLimitError(Exception): + pass + + @rate_limited( + max_calls=1, + period_seconds=60, + exception_class=CustomRateLimitError + ) + def limited_func(): + return "success" + + limited_func() + + with pytest.raises(CustomRateLimitError): + limited_func() + + +class TestWithRetryDecorator: + """Tests for @with_retry decorator.""" + + def test_retry_on_retryable_error(self): + """Test retry on retryable exceptions.""" + attempt_count = [0] + + @with_retry(max_retries=3, retry_delay=0.01) + def 
failing_func(): + attempt_count[0] += 1 + if attempt_count[0] < 3: + raise ConnectionError("Network error") + return "success" + + result = failing_func() + assert result == "success" + assert attempt_count[0] == 3 + + def test_no_retry_on_non_retryable_error(self): + """Test no retry on non-retryable exceptions.""" + attempt_count = [0] + + @with_retry( + max_retries=3, + retry_delay=0.01, + retryable_exceptions=(ConnectionError,) + ) + def failing_func(): + attempt_count[0] += 1 + raise ValueError("Not retryable") + + with pytest.raises(ValueError): + failing_func() + + assert attempt_count[0] == 1 + + def test_retry_exhausted(self): + """Test exception raised after max retries.""" + @with_retry(max_retries=2, retry_delay=0.01) + def always_fails(): + raise ConnectionError("Always fails") + + with pytest.raises(ConnectionError): + always_fails() + + +class TestCacheResultDecorator: + """Tests for @cache_result decorator.""" + + def test_cache_returns_cached_value(self): + """Test that cached value is returned.""" + call_count = [0] + + @cache_result(ttl_seconds=300) + def expensive_func(key): + call_count[0] += 1 + return f"value_{key}" + + # First call + result1 = expensive_func("test") + # Second call (should use cache) + result2 = expensive_func("test") + + assert result1 == result2 + assert call_count[0] == 1 + + def test_cache_different_keys(self): + """Test that different keys have different cache entries.""" + call_count = [0] + + @cache_result(ttl_seconds=300) + def expensive_func(key): + call_count[0] += 1 + return f"value_{key}" + + result1 = expensive_func("key1") + result2 = expensive_func("key2") + + assert result1 != result2 + assert call_count[0] == 2 + + def test_cache_clear(self): + """Test clearing the cache.""" + call_count = [0] + + @cache_result(ttl_seconds=300) + def expensive_func(key): + call_count[0] += 1 + return f"value_{key}" + + expensive_func("test") + expensive_func.clear_cache() + expensive_func("test") + + assert call_count[0] == 
2 + + def test_cache_info(self): + """Test cache info method.""" + @cache_result(ttl_seconds=60, max_size=100) + def cached_func(key): + return key + + cached_func("test1") + cached_func("test2") + + info = cached_func.cache_info() + assert info["size"] == 2 + assert info["max_size"] == 100 + assert info["ttl_seconds"] == 60 + + +class TestDecoratorStacking: + """Tests for stacking multiple decorators.""" + + def test_register_and_method_decorators(self): + """Test stacking @register_vendor and @vendor_method.""" + @register_vendor( + name="stacked_vendor", + capabilities={VendorCapability.STOCK_DATA} + ) + class StackedVendor: + @vendor_method("get_stock_data") + def fetch_stock(self, ticker): + return f"Stock: {ticker}" + + vendor = StackedVendor() + + # Vendor should be registered + registry = VendorRegistry() + assert registry.get_vendor("stacked_vendor") is not None + + # Method should work + result = vendor.fetch_stock("AAPL") + assert result == "Stock: AAPL" + + def test_method_with_rate_limit(self): + """Test stacking @vendor_method with @rate_limited.""" + call_count = [0] + + class TestVendor: + @vendor_method("get_data") + @rate_limited(max_calls=3, period_seconds=60) + def fetch_data(self): + call_count[0] += 1 + return "data" + + vendor = TestVendor() + + # Should work 3 times + for _ in range(3): + vendor.fetch_data() + + # 4th should fail + with pytest.raises(RuntimeError): + vendor.fetch_data() + + assert call_count[0] == 3 + + +class TestGetRateLimiter: + """Tests for get_rate_limiter function.""" + + def test_returns_same_limiter_for_same_params(self): + """Test that same params return same limiter.""" + limiter1 = get_rate_limiter("vendor", 5, 60) + limiter2 = get_rate_limiter("vendor", 5, 60) + + assert limiter1 is limiter2 + + def test_returns_different_limiter_for_different_vendor(self): + """Test that different vendors get different limiters.""" + limiter1 = get_rate_limiter("vendor1", 5, 60) + limiter2 = get_rate_limiter("vendor2", 5, 60) + 
+ assert limiter1 is not limiter2 + + def test_returns_different_limiter_for_different_params(self): + """Test that different params get different limiters.""" + limiter1 = get_rate_limiter("vendor", 5, 60) + limiter2 = get_rate_limiter("vendor", 10, 60) + + assert limiter1 is not limiter2 diff --git a/tests/unit/dataflows/test_vendor_registry.py b/tests/unit/dataflows/test_vendor_registry.py new file mode 100644 index 00000000..63574a54 --- /dev/null +++ b/tests/unit/dataflows/test_vendor_registry.py @@ -0,0 +1,354 @@ +"""Tests for VendorRegistry. + +Issue #11: [DATA-10] Interface routing - add new data vendors +""" + +import pytest +import threading +from unittest.mock import Mock + +from tradingagents.dataflows.vendor_registry import ( + VendorCapability, + VendorMetadata, + VendorRegistry, + get_registry, + register_vendor, + register_method, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def reset_registry(): + """Reset the singleton registry before each test.""" + VendorRegistry.reset_instance() + yield + VendorRegistry.reset_instance() + + +class TestVendorMetadata: + """Tests for VendorMetadata dataclass.""" + + def test_default_values(self): + """Test default values are set correctly.""" + metadata = VendorMetadata(name="test") + assert metadata.name == "test" + assert metadata.priority == 100 + assert metadata.capabilities == set() + assert metadata.rate_limit_exception is None + assert metadata.description == "" + assert metadata.enabled is True + + def test_custom_values(self): + """Test custom values are preserved.""" + metadata = VendorMetadata( + name="yfinance", + priority=10, + capabilities={VendorCapability.STOCK_DATA}, + description="Yahoo Finance vendor" + ) + assert metadata.name == "yfinance" + assert metadata.priority == 10 + assert VendorCapability.STOCK_DATA in metadata.capabilities + assert metadata.description == "Yahoo Finance vendor" + + def test_hash(self): + """Test VendorMetadata is hashable by name.""" + m1 
= VendorMetadata(name="vendor1") + m2 = VendorMetadata(name="vendor1", priority=50) + assert hash(m1) == hash(m2) + + def test_equality(self): + """Test VendorMetadata equality is by name.""" + m1 = VendorMetadata(name="vendor1", priority=10) + m2 = VendorMetadata(name="vendor1", priority=50) + m3 = VendorMetadata(name="vendor2") + assert m1 == m2 + assert m1 != m3 + + def test_equality_with_non_metadata(self): + """Test equality with non-VendorMetadata objects.""" + metadata = VendorMetadata(name="test") + assert metadata != "test" + assert metadata != {"name": "test"} + + +class TestVendorRegistry: + """Tests for VendorRegistry singleton.""" + + def test_singleton_pattern(self): + """Test that VendorRegistry is a singleton.""" + r1 = VendorRegistry() + r2 = VendorRegistry() + assert r1 is r2 + + def test_get_registry_returns_singleton(self): + """Test get_registry returns same instance.""" + r1 = get_registry() + r2 = VendorRegistry() + assert r1 is r2 + + def test_reset_instance(self): + """Test reset_instance creates new singleton.""" + r1 = VendorRegistry() + VendorRegistry.reset_instance() + r2 = VendorRegistry() + assert r1 is not r2 + + def test_register_vendor(self): + """Test registering a vendor.""" + registry = VendorRegistry() + metadata = VendorMetadata( + name="yfinance", + capabilities={VendorCapability.STOCK_DATA} + ) + registry.register_vendor(metadata) + + result = registry.get_vendor("yfinance") + assert result is not None + assert result.name == "yfinance" + + def test_register_vendor_empty_name_raises(self): + """Test registering vendor with empty name raises ValueError.""" + registry = VendorRegistry() + with pytest.raises(ValueError, match="cannot be empty"): + registry.register_vendor(VendorMetadata(name="")) + + def test_unregister_vendor(self): + """Test unregistering a vendor.""" + registry = VendorRegistry() + metadata = VendorMetadata(name="test_vendor") + registry.register_vendor(metadata) + + assert 
registry.unregister_vendor("test_vendor") is True + assert registry.get_vendor("test_vendor") is None + + def test_unregister_nonexistent_vendor(self): + """Test unregistering non-existent vendor returns False.""" + registry = VendorRegistry() + assert registry.unregister_vendor("nonexistent") is False + + def test_get_all_vendors_sorted_by_priority(self): + """Test get_all_vendors returns vendors sorted by priority.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="low", priority=100)) + registry.register_vendor(VendorMetadata(name="high", priority=10)) + registry.register_vendor(VendorMetadata(name="medium", priority=50)) + + vendors = registry.get_all_vendors() + assert [v.name for v in vendors] == ["high", "medium", "low"] + + def test_get_vendors_for_capability(self): + """Test getting vendors by capability.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata( + name="yfinance", + capabilities={VendorCapability.STOCK_DATA, VendorCapability.FUNDAMENTALS} + )) + registry.register_vendor(VendorMetadata( + name="alpha_vantage", + capabilities={VendorCapability.STOCK_DATA} + )) + + stock_vendors = registry.get_vendors_for_capability(VendorCapability.STOCK_DATA) + assert len(stock_vendors) == 2 + + fundamental_vendors = registry.get_vendors_for_capability(VendorCapability.FUNDAMENTALS) + assert len(fundamental_vendors) == 1 + assert fundamental_vendors[0].name == "yfinance" + + def test_get_vendors_for_capability_only_enabled(self): + """Test only enabled vendors are returned by default.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata( + name="enabled_vendor", + capabilities={VendorCapability.STOCK_DATA} + )) + registry.register_vendor(VendorMetadata( + name="disabled_vendor", + capabilities={VendorCapability.STOCK_DATA}, + enabled=False + )) + + vendors = registry.get_vendors_for_capability(VendorCapability.STOCK_DATA) + assert len(vendors) == 1 + assert vendors[0].name == 
"enabled_vendor" + + all_vendors = registry.get_vendors_for_capability( + VendorCapability.STOCK_DATA, + only_enabled=False + ) + assert len(all_vendors) == 2 + + +class TestVendorRegistryMethods: + """Tests for method registration in VendorRegistry.""" + + def test_register_method(self): + """Test registering a method for a vendor.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="yfinance")) + + mock_func = Mock() + registry.register_method("yfinance", "get_stock", mock_func) + + result = registry.get_method("yfinance", "get_stock") + assert result is mock_func + + def test_register_method_unregistered_vendor_raises(self): + """Test registering method for unregistered vendor raises.""" + registry = VendorRegistry() + with pytest.raises(ValueError, match="not registered"): + registry.register_method("nonexistent", "method", Mock()) + + def test_register_method_empty_name_raises(self): + """Test registering method with empty name raises.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="test")) + with pytest.raises(ValueError, match="cannot be empty"): + registry.register_method("test", "", Mock()) + + def test_get_method_nonexistent(self): + """Test getting non-existent method returns None.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="test")) + assert registry.get_method("test", "nonexistent") is None + + def test_get_methods_for_vendor(self): + """Test getting all methods for a vendor.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="yfinance")) + + func1 = Mock() + func2 = Mock() + registry.register_method("yfinance", "get_stock", func1) + registry.register_method("yfinance", "get_fundamentals", func2) + + methods = registry.get_methods_for_vendor("yfinance") + assert len(methods) == 2 + assert methods["get_stock"] is func1 + assert methods["get_fundamentals"] is func2 + + def test_get_methods_for_unregistered_vendor(self): + """Test 
getting methods for unregistered vendor returns empty dict.""" + registry = VendorRegistry() + assert registry.get_methods_for_vendor("nonexistent") == {} + + +class TestVendorRegistryVendorControl: + """Tests for vendor enable/disable and priority control.""" + + def test_set_vendor_enabled(self): + """Test enabling/disabling a vendor.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="test", enabled=True)) + + assert registry.set_vendor_enabled("test", False) is True + assert registry.get_vendor("test").enabled is False + + assert registry.set_vendor_enabled("test", True) is True + assert registry.get_vendor("test").enabled is True + + def test_set_vendor_enabled_nonexistent(self): + """Test setting enabled on non-existent vendor.""" + registry = VendorRegistry() + assert registry.set_vendor_enabled("nonexistent", True) is False + + def test_set_vendor_priority(self): + """Test updating vendor priority.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="test", priority=100)) + + assert registry.set_vendor_priority("test", 10) is True + assert registry.get_vendor("test").priority == 10 + + def test_set_vendor_priority_nonexistent(self): + """Test setting priority on non-existent vendor.""" + registry = VendorRegistry() + assert registry.set_vendor_priority("nonexistent", 10) is False + + def test_clear(self): + """Test clearing all registrations.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata( + name="test", + capabilities={VendorCapability.STOCK_DATA} + )) + registry.register_method("test", "method", Mock()) + + registry.clear() + + assert registry.get_vendor("test") is None + assert len(registry.get_all_vendors()) == 0 + + +class TestVendorRegistryThreadSafety: + """Tests for thread safety of VendorRegistry.""" + + def test_concurrent_registration(self): + """Test concurrent vendor registration is thread-safe.""" + registry = VendorRegistry() + errors = [] + + def 
register_vendor(i): + try: + registry.register_vendor(VendorMetadata( + name=f"vendor_{i}", + capabilities={VendorCapability.STOCK_DATA} + )) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=register_vendor, args=(i,)) for i in range(50)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert len(registry.get_all_vendors()) == 50 + + def test_concurrent_method_registration(self): + """Test concurrent method registration is thread-safe.""" + registry = VendorRegistry() + registry.register_vendor(VendorMetadata(name="test")) + errors = [] + + def register_method(i): + try: + registry.register_method("test", f"method_{i}", Mock()) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=register_method, args=(i,)) for i in range(50)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert len(registry.get_methods_for_vendor("test")) == 50 + + +class TestModuleFunctions: + """Tests for module-level convenience functions.""" + + def test_register_vendor_function(self): + """Test module-level register_vendor function.""" + metadata = VendorMetadata(name="module_test") + register_vendor(metadata) + + registry = get_registry() + assert registry.get_vendor("module_test") is not None + + def test_register_method_function(self): + """Test module-level register_method function.""" + metadata = VendorMetadata(name="method_test") + register_vendor(metadata) + + mock_func = Mock() + register_method("method_test", "test_method", mock_func) + + registry = get_registry() + assert registry.get_method("method_test", "test_method") is mock_func diff --git a/tradingagents/dataflows/base_vendor.py b/tradingagents/dataflows/base_vendor.py new file mode 100644 index 00000000..2f37a213 --- /dev/null +++ b/tradingagents/dataflows/base_vendor.py @@ -0,0 +1,365 @@ +"""Base vendor abstract class for data provider implementations. 
# --- tradingagents/dataflows/base_vendor.py ---
"""Base vendor abstract class for data provider implementations.

This module provides the abstract base class that all data vendors must inherit from,
implementing a 3-stage data pipeline: transform_query → extract_data → transform_data.

Issue #11: [DATA-10] Interface routing - add new data vendors
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic
import threading
import time
import logging

logger = logging.getLogger(__name__)

T = TypeVar('T')


@dataclass
class VendorResponse(Generic[T]):
    """Standardized response from vendor data operations.

    Attributes:
        data: The extracted data
        success: Whether the operation succeeded
        vendor_name: Name of the vendor that provided the data
        method_name: Name of the method called
        execution_time_ms: Time taken in milliseconds
        error_message: Error message if operation failed
        metadata: Additional metadata about the response
    """
    data: Optional[T] = None
    success: bool = True
    vendor_name: str = ""
    method_name: str = ""
    execution_time_ms: float = 0.0
    error_message: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_empty(self) -> bool:
        """Return True when ``data`` is None or an empty list, dict or str.

        Any other value (including ``0`` or an empty tuple) counts as data.
        """
        if self.data is None:
            return True
        if isinstance(self.data, (list, dict, str)):
            return not self.data
        return False


class BaseVendor(ABC):
    """Abstract base class for all data vendors.

    Implements a 3-stage pipeline pattern:
    1. transform_query: Normalize input parameters
    2. extract_data: Fetch raw data from source
    3. transform_data: Normalize output format

    Call, success and error counters are all guarded by one lock, so the
    statistics stay consistent when ``execute`` is called from several threads.

    Example:
        class YFinanceVendor(BaseVendor):
            @property
            def name(self) -> str:
                return "yfinance"

            def _transform_query(self, method, **kwargs):
                # Normalize ticker symbol
                return {"symbol": kwargs["ticker"].upper()}

            def _extract_data(self, method, query):
                # Fetch from yfinance
                return yf.Ticker(query["symbol"]).history()

            def _transform_data(self, method, raw_data, query):
                # Convert to standard format
                return {"ohlcv": raw_data.to_dict()}
    """

    def __init__(
        self,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff: float = 2.0,
        timeout: Optional[float] = None
    ):
        """Initialize vendor with retry configuration.

        Args:
            max_retries: Maximum number of retry attempts after the first try.
                Negative values are treated as 0 (a single attempt).
            retry_delay: Initial delay between retries in seconds
            retry_backoff: Multiplier for delay after each retry
            timeout: Optional timeout in seconds. It is only stored here;
                nothing in this class enforces it.
        """
        self._max_retries = max_retries
        self._retry_delay = retry_delay
        self._retry_backoff = retry_backoff
        self._timeout = timeout
        self._call_count = 0
        # Guards every statistics field below, not only ``_call_count``.
        self._call_count_lock = threading.Lock()
        self._last_call_time: Optional[datetime] = None
        self._error_count = 0
        self._success_count = 0

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the vendor name."""
        pass

    @property
    def call_count(self) -> int:
        """Get the total number of calls made (thread-safe)."""
        with self._call_count_lock:
            return self._call_count

    @property
    def error_rate(self) -> float:
        """Return failed calls as a percentage of completed calls.

        Returns 0.0 when no call has completed yet. Both counters are read
        under the lock so the ratio is taken from one consistent snapshot.
        """
        with self._call_count_lock:
            errors = self._error_count
            total = errors + self._success_count
        if total == 0:
            return 0.0
        return (errors / total) * 100

    def reset_stats(self) -> None:
        """Reset call statistics."""
        with self._call_count_lock:
            self._call_count = 0
            self._error_count = 0
            self._success_count = 0
            self._last_call_time = None

    @abstractmethod
    def _transform_query(self, method: str, **kwargs) -> Dict[str, Any]:
        """Transform input parameters to vendor-specific format.

        Args:
            method: The method being called
            **kwargs: Input parameters

        Returns:
            Transformed query parameters
        """
        pass

    @abstractmethod
    def _extract_data(self, method: str, query: Dict[str, Any]) -> Any:
        """Extract raw data from the vendor.

        Args:
            method: The method being called
            query: Transformed query parameters

        Returns:
            Raw data from the vendor
        """
        pass

    @abstractmethod
    def _transform_data(
        self,
        method: str,
        raw_data: Any,
        query: Dict[str, Any]
    ) -> Any:
        """Transform raw data to standardized format.

        Args:
            method: The method being called
            raw_data: Raw data from extract_data
            query: Original query parameters

        Returns:
            Transformed data in standard format
        """
        pass

    def _should_retry(self, exception: Exception) -> bool:
        """Determine if operation should be retried for given exception.

        Override this method to customize retry logic.

        Args:
            exception: The exception that occurred

        Returns:
            True if operation should be retried
        """
        # ConnectionError and TimeoutError are OSError subclasses; they are
        # listed explicitly to document intent. Note that OSError also covers
        # errors such as FileNotFoundError.
        retryable_types = (
            ConnectionError,
            TimeoutError,
            OSError,
        )
        return isinstance(exception, retryable_types)

    def execute(
        self,
        method: str,
        **kwargs
    ) -> VendorResponse:
        """Execute a vendor method with full pipeline.

        Runs the 3-stage pipeline with retry logic:
        1. transform_query
        2. extract_data (with retries)
        3. transform_data

        Any exception raised by a stage is caught and reported through the
        returned response instead of propagating to the caller.

        Args:
            method: Name of the method to execute
            **kwargs: Parameters for the method

        Returns:
            VendorResponse with result or error information
        """
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the system clock is adjusted.
        started = time.perf_counter()

        with self._call_count_lock:
            self._call_count += 1
            self._last_call_time = datetime.now()

        response = VendorResponse(
            vendor_name=self.name,
            method_name=method
        )

        try:
            # Stage 1: Transform query
            query = self._transform_query(method, **kwargs)
            logger.debug("[%s] Transformed query for %s: %s", self.name, method, query)

            # Stage 2: Extract data with retries
            raw_data = self._extract_with_retry(method, query)

            # Stage 3: Transform data
            response.data = self._transform_data(method, raw_data, query)
            response.success = True
            with self._call_count_lock:
                self._success_count += 1

        except Exception as e:
            response.success = False
            response.error_message = str(e)
            with self._call_count_lock:
                self._error_count += 1
            logger.error("[%s] Error executing %s: %s", self.name, method, e)

        finally:
            response.execution_time_ms = (time.perf_counter() - started) * 1000

        return response

    def _extract_with_retry(self, method: str, query: Dict[str, Any]) -> Any:
        """Execute extract_data with retry logic.

        Args:
            method: Method name
            query: Query parameters

        Returns:
            Raw data from vendor

        Raises:
            Exception: The non-retryable error, or the last retryable error
                once all attempts are used up.
        """
        # Always make at least one attempt, even for a negative max_retries;
        # otherwise the loop would not run and there would be nothing to raise.
        attempts = max(self._max_retries, 0) + 1
        delay = self._retry_delay

        for attempt in range(1, attempts + 1):
            try:
                return self._extract_data(method, query)

            except Exception as e:
                if not self._should_retry(e):
                    logger.debug("[%s] Non-retryable error: %s", self.name, e)
                    raise

                if attempt == attempts:
                    # All retries exhausted
                    raise

                logger.warning(
                    "[%s] Retry %d/%d after error: %s. Waiting %.1fs",
                    self.name, attempt, self._max_retries, e, delay
                )
                time.sleep(delay)
                delay *= self._retry_backoff


class SimpleVendor(BaseVendor):
    """Simplified vendor for wrapping existing functions.

    Useful for migrating existing vendor implementations to the new pattern
    without major refactoring.

    Example:
        vendor = SimpleVendor(
            vendor_name="yfinance",
            methods={
                "get_stock_data": get_yfinance_stock,
                "get_fundamentals": get_yfinance_fundamentals,
            }
        )
    """

    def __init__(
        self,
        vendor_name: str,
        methods: Dict[str, Callable[..., Any]],
        **kwargs
    ):
        """Initialize simple vendor with existing functions.

        Args:
            vendor_name: Name of the vendor
            methods: Dictionary mapping method names to callables
            **kwargs: Additional BaseVendor configuration
        """
        super().__init__(**kwargs)
        self._vendor_name = vendor_name
        # Copy so that add_method() never mutates the caller's dictionary.
        self._methods: Dict[str, Callable[..., Any]] = dict(methods)

    @property
    def name(self) -> str:
        """Return the vendor name."""
        return self._vendor_name

    def _transform_query(self, method: str, **kwargs) -> Dict[str, Any]:
        """Pass through query parameters unchanged."""
        return kwargs

    def _extract_data(self, method: str, query: Dict[str, Any]) -> Any:
        """Call the wrapped function with the query as keyword arguments.

        Raises:
            ValueError: If no callable is registered under ``method``.
        """
        if method not in self._methods:
            raise ValueError(f"Method '{method}' not found in vendor '{self.name}'")

        func = self._methods[method]
        return func(**query)

    def _transform_data(
        self,
        method: str,
        raw_data: Any,
        query: Dict[str, Any]
    ) -> Any:
        """Return data unchanged."""
        return raw_data

    def add_method(self, method_name: str, func: Callable[..., Any]) -> None:
        """Add a method to the vendor.

        Args:
            method_name: Name of the method
            func: Callable to execute
        """
        self._methods[method_name] = func

    def get_methods(self) -> List[str]:
        """Get list of available methods."""
        return list(self._methods)
# --- tradingagents/dataflows/vendor_decorators.py ---
"""Decorators for vendor registration and method management.

Provides convenient decorators for registering vendors and methods with
the global registry.

Issue #11: [DATA-10] Interface routing - add new data vendors
"""

from functools import wraps
from typing import Callable, Optional, Set, Type
import threading
import time
import logging

from .vendor_registry import (
    VendorCapability,
    VendorMetadata,
    VendorRegistry,
    get_registry
)

logger = logging.getLogger(__name__)


def register_vendor(
    name: str,
    priority: int = 100,
    capabilities: Optional[Set[VendorCapability]] = None,
    rate_limit_exception: Optional[Type[Exception]] = None,
    description: str = ""
) -> Callable:
    """Build a class decorator that registers a vendor in the global registry.

    The vendor is registered when the decorated class is defined. The
    metadata is stored on the class as ``_vendor_metadata`` and the name as
    ``_vendor_name``; the class itself is returned unchanged otherwise.

    Args:
        name: Vendor name used as the registry key
        priority: Routing priority passed to VendorMetadata
        capabilities: Capabilities the vendor provides (empty set if omitted)
        rate_limit_exception: Exception type stored on the metadata
        description: Free-text description stored on the metadata

    Returns:
        A decorator that takes and returns the vendor class.

    Example:
        @register_vendor(
            name="yfinance",
            priority=10,
            capabilities={VendorCapability.STOCK_DATA, VendorCapability.FUNDAMENTALS}
        )
        class YFinanceVendor(BaseVendor):
            pass
    """
    def _attach(vendor_cls: Type) -> Type:
        declared = capabilities if capabilities else set()
        meta = VendorMetadata(
            name=name,
            priority=priority,
            capabilities=declared,
            rate_limit_exception=rate_limit_exception,
            description=description
        )

        # Registration happens first: if it raises, the class is left untouched.
        registry = get_registry()
        registry.register_vendor(meta)

        # Keep a reference on the class so callers can inspect it later.
        vendor_cls._vendor_metadata = meta
        vendor_cls._vendor_name = name
        return vendor_cls

    return _attach


def vendor_method(
    method_name: str,
    vendor_name: Optional[str] = None
) -> Callable:
    """Build a decorator that tags a callable as a vendor method.

    The returned wrapper always carries ``_vendor_method_name`` and
    ``_vendor_name``. It is registered with the global registry only when
    ``vendor_name`` is given and the registry accepts it.

    Example (on class method):
        class YFinanceVendor(BaseVendor):
            @vendor_method("get_stock_data")
            def get_stock(self, ticker):
                pass

    Example (standalone):
        @vendor_method("get_stock_data", vendor_name="yfinance")
        def get_yfinance_stock(ticker):
            pass
    """
    def _tag(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*call_args, **call_kwargs):
            return func(*call_args, **call_kwargs)

        # Recorded so the method name and vendor can be recovered later.
        wrapper._vendor_method_name = method_name
        wrapper._vendor_name = vendor_name

        if not vendor_name:
            return wrapper

        try:
            get_registry().register_method(vendor_name, method_name, wrapper)
        except ValueError:
            # NOTE(review): the registry raises ValueError both for an unknown
            # vendor and for an empty method name; both end up here. Nothing
            # in this module retries the registration later - confirm whether
            # another component is expected to pick these up.
            logger.debug(f"Deferred registration of {method_name} for {vendor_name}")

        return wrapper

    return _tag
+ """ + + def __init__(self, max_calls: int, period_seconds: float): + """Initialize rate limiter. + + Args: + max_calls: Maximum calls allowed in the period + period_seconds: Time window in seconds + """ + self._max_calls = max_calls + self._period = period_seconds + self._calls: list = [] + self._lock = threading.Lock() + + def acquire(self) -> bool: + """Try to acquire a rate limit slot. + + Returns: + True if request is allowed, False if rate limited + """ + with self._lock: + now = time.time() + cutoff = now - self._period + + # Remove expired entries + self._calls = [t for t in self._calls if t > cutoff] + + # Check if under limit + if len(self._calls) < self._max_calls: + self._calls.append(now) + return True + + return False + + def wait_time(self) -> float: + """Get time to wait before next allowed request. + + Returns: + Seconds to wait (0 if request can proceed) + """ + with self._lock: + now = time.time() + cutoff = now - self._period + + # Remove expired entries + self._calls = [t for t in self._calls if t > cutoff] + + if len(self._calls) < self._max_calls: + return 0.0 + + # Time until oldest call expires + return self._calls[0] + self._period - now + + def reset(self) -> None: + """Reset the rate limiter.""" + with self._lock: + self._calls.clear() + + +# Global rate limiters for vendors +_rate_limiters: dict = {} +_rate_limiter_lock = threading.Lock() + + +def get_rate_limiter( + vendor_name: str, + max_calls: int, + period_seconds: float +) -> RateLimiter: + """Get or create a rate limiter for a vendor. 
+ + Args: + vendor_name: Vendor identifier + max_calls: Max calls per period + period_seconds: Rate limit period + + Returns: + RateLimiter instance + """ + with _rate_limiter_lock: + key = f"{vendor_name}:{max_calls}:{period_seconds}" + if key not in _rate_limiters: + _rate_limiters[key] = RateLimiter(max_calls, period_seconds) + return _rate_limiters[key] + + +def rate_limited( + max_calls: int, + period_seconds: float = 60.0, + vendor_name: Optional[str] = None, + exception_class: Optional[Type[Exception]] = None +) -> Callable: + """Decorator to apply rate limiting to a function. + + Example: + @rate_limited(max_calls=5, period_seconds=60, vendor_name="alpha_vantage") + def get_stock_data(ticker): + pass + """ + def decorator(func: Callable) -> Callable: + # Determine vendor name from function or provided value + limiter_name = vendor_name or getattr(func, '_vendor_name', func.__name__) + + @wraps(func) + def wrapper(*args, **kwargs): + limiter = get_rate_limiter(limiter_name, max_calls, period_seconds) + + if not limiter.acquire(): + wait_seconds = limiter.wait_time() + error_msg = ( + f"Rate limit exceeded for {limiter_name}. " + f"Try again in {wait_seconds:.1f} seconds" + ) + + if exception_class: + raise exception_class(error_msg) + raise RuntimeError(error_msg) + + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def with_retry( + max_retries: int = 3, + retry_delay: float = 1.0, + backoff_multiplier: float = 2.0, + retryable_exceptions: tuple = (ConnectionError, TimeoutError) +) -> Callable: + """Decorator to add retry logic to a function. 
def with_retry(
    max_retries: int = 3,
    retry_delay: float = 1.0,
    backoff_multiplier: float = 2.0,
    retryable_exceptions: tuple = (ConnectionError, TimeoutError)
) -> Callable:
    """Decorator to add retry logic to a function.

    The function is called once and then retried up to ``max_retries`` times
    when it raises one of ``retryable_exceptions``. The pause between
    attempts starts at ``retry_delay`` seconds and is multiplied by
    ``backoff_multiplier`` after each retry. Other exceptions propagate
    immediately; the last retryable exception propagates once attempts run
    out.

    Args:
        max_retries: Retries after the first attempt. Negative values are
            treated as 0, so the function always runs at least once.
        retry_delay: Initial delay between attempts in seconds
        backoff_multiplier: Factor applied to the delay after each retry
        retryable_exceptions: Exception types that trigger a retry

    Example:
        @with_retry(max_retries=3, retry_delay=1.0)
        def fetch_data():
            pass
    """
    def decorator(func: Callable) -> Callable:
        func_label = getattr(func, "__name__", repr(func))

        @wraps(func)
        def wrapper(*args, **kwargs):
            # At least one attempt, even when max_retries is negative.
            attempts = max(max_retries, 0) + 1
            delay = retry_delay

            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except retryable_exceptions as e:
                    if attempt == attempts:
                        raise
                    logger.warning(
                        "Retry %d/%d for %s: %s",
                        attempt, max_retries, func_label, e
                    )
                    time.sleep(delay)
                    delay *= backoff_multiplier

        return wrapper

    return decorator


def cache_result(
    ttl_seconds: float = 300.0,
    max_size: int = 100
) -> Callable:
    """Decorator to cache function results.

    Simple TTL-based cache with LRU eviction. Entries are keyed by the
    positional arguments plus the keyword arguments sorted by name. Calls
    whose arguments are not hashable bypass the cache and run the function
    directly.

    The wrapper exposes two helpers:
        clear_cache(): drop every cached entry
        cache_info(): dict with ``size``, ``max_size`` and ``ttl_seconds``

    Args:
        ttl_seconds: How long an entry stays valid, in seconds
        max_size: Maximum number of entries kept

    Example:
        @cache_result(ttl_seconds=60)
        def get_stock_data(ticker):
            pass
    """
    # Local import: the names imported at the top of this module are not
    # changed by this function.
    from collections import OrderedDict

    capacity = max(max_size, 0)

    def decorator(func: Callable) -> Callable:
        # key -> (value, stored_at); order runs from least to most recently
        # used, so one structure provides both lookup and eviction order.
        cache: "OrderedDict" = OrderedDict()
        cache_lock = threading.Lock()

        @wraps(func)
        def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))
            try:
                hash(key)
            except TypeError:
                # Unhashable arguments cannot be used as a cache key.
                return func(*args, **kwargs)

            # Monotonic clock: entry age is unaffected by system clock changes.
            now = time.monotonic()

            with cache_lock:
                entry = cache.get(key)
                if entry is not None:
                    value, stored_at = entry
                    if now - stored_at < ttl_seconds:
                        cache.move_to_end(key)
                        return value
                    # Expired
                    del cache[key]

            # Run outside the lock so a slow call does not block other keys.
            result = func(*args, **kwargs)

            with cache_lock:
                # Another call may have filled this key in the meantime;
                # overwriting keeps exactly one entry per key.
                cache[key] = (result, now)
                cache.move_to_end(key)
                while len(cache) > capacity:
                    cache.popitem(last=False)

            return result

        # Add cache control methods
        def clear_cache():
            with cache_lock:
                cache.clear()

        def cache_info():
            with cache_lock:
                return {
                    "size": len(cache),
                    "max_size": max_size,
                    "ttl_seconds": ttl_seconds
                }

        wrapper.clear_cache = clear_cache
        wrapper.cache_info = cache_info

        return wrapper

    return decorator
"max_size": max_size, + "ttl_seconds": ttl_seconds + } + + wrapper.clear_cache = clear_cache + wrapper.cache_info = cache_info + + return wrapper + + return decorator diff --git a/tradingagents/dataflows/vendor_registry.py b/tradingagents/dataflows/vendor_registry.py new file mode 100644 index 00000000..63fb1d61 --- /dev/null +++ b/tradingagents/dataflows/vendor_registry.py @@ -0,0 +1,329 @@ +"""Vendor Registry for extensible data vendor management. + +This module provides a thread-safe registry pattern for managing data vendors, +enabling easy addition of new vendors without modifying core interface code. + +Issue #11: [DATA-10] Interface routing - add new data vendors +""" + +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Callable, Dict, List, Optional, Set, Any, Type +import threading +import logging + +logger = logging.getLogger(__name__) + + +class VendorCapability(Enum): + """Capabilities that vendors can provide.""" + STOCK_DATA = auto() + TECHNICAL_INDICATORS = auto() + FUNDAMENTALS = auto() + BALANCE_SHEET = auto() + CASHFLOW = auto() + INCOME_STATEMENT = auto() + NEWS = auto() + GLOBAL_NEWS = auto() + INSIDER_SENTIMENT = auto() + INSIDER_TRANSACTIONS = auto() + MACROECONOMIC = auto() + BENCHMARK = auto() + + +@dataclass +class VendorMetadata: + """Metadata about a registered vendor.""" + name: str + priority: int = 100 # Lower = higher priority + capabilities: Set[VendorCapability] = field(default_factory=set) + rate_limit_exception: Optional[Type[Exception]] = None + description: str = "" + enabled: bool = True + + def __hash__(self): + return hash(self.name) + + def __eq__(self, other): + if not isinstance(other, VendorMetadata): + return False + return self.name == other.name + + +class VendorRegistry: + """Thread-safe singleton registry for data vendors. + + Provides centralized vendor registration, lookup, and routing. + Uses double-checked locking for thread-safe singleton pattern. 
+ + Example: + # Register a vendor + registry = VendorRegistry() + registry.register_vendor( + VendorMetadata( + name="yfinance", + priority=10, + capabilities={VendorCapability.STOCK_DATA, VendorCapability.FUNDAMENTALS} + ) + ) + + # Register a method + registry.register_method("yfinance", "get_stock_data", get_yfinance_stock) + + # Get vendors for a capability + vendors = registry.get_vendors_for_capability(VendorCapability.STOCK_DATA) + """ + + _instance: Optional["VendorRegistry"] = None + _lock: threading.Lock = threading.Lock() + _initialized: bool = False + + def __new__(cls) -> "VendorRegistry": + """Thread-safe singleton instantiation with double-checked locking.""" + # Fast path: instance exists + if cls._instance is not None: + return cls._instance + + # Slow path: acquire lock and create instance + with cls._lock: + # Double-check inside lock + if cls._instance is None: + instance = super().__new__(cls) + # Initialize instance attributes before publishing + instance._vendors: Dict[str, VendorMetadata] = {} + instance._methods: Dict[str, Dict[str, Callable]] = {} + instance._capability_index: Dict[VendorCapability, Set[str]] = {} + instance._vendor_lock = threading.RLock() + # Publish instance only after fully initialized + cls._instance = instance + return cls._instance + + def __init__(self): + """Initialize registry (only runs once due to singleton).""" + # Skip if already initialized + if VendorRegistry._initialized: + return + VendorRegistry._initialized = True + + def register_vendor(self, metadata: VendorMetadata) -> None: + """Register a new vendor with its metadata. 
+ + Args: + metadata: VendorMetadata containing vendor info and capabilities + + Raises: + ValueError: If vendor name is empty + """ + if not metadata.name: + raise ValueError("Vendor name cannot be empty") + + with self._vendor_lock: + self._vendors[metadata.name] = metadata + + # Update capability index + for capability in metadata.capabilities: + if capability not in self._capability_index: + self._capability_index[capability] = set() + self._capability_index[capability].add(metadata.name) + + logger.debug(f"Registered vendor: {metadata.name} with capabilities: {metadata.capabilities}") + + def unregister_vendor(self, name: str) -> bool: + """Unregister a vendor. + + Args: + name: Name of vendor to unregister + + Returns: + True if vendor was unregistered, False if not found + """ + with self._vendor_lock: + if name not in self._vendors: + return False + + metadata = self._vendors.pop(name) + + # Remove from capability index + for capability in metadata.capabilities: + if capability in self._capability_index: + self._capability_index[capability].discard(name) + + # Remove all registered methods + if name in self._methods: + del self._methods[name] + + logger.debug(f"Unregistered vendor: {name}") + return True + + def register_method( + self, + vendor_name: str, + method_name: str, + implementation: Callable + ) -> None: + """Register a method implementation for a vendor. 
+ + Args: + vendor_name: Name of the vendor + method_name: Name of the method + implementation: Callable implementation + + Raises: + ValueError: If vendor is not registered or method name is empty + """ + if not method_name: + raise ValueError("Method name cannot be empty") + + with self._vendor_lock: + if vendor_name not in self._vendors: + raise ValueError(f"Vendor '{vendor_name}' not registered") + + if vendor_name not in self._methods: + self._methods[vendor_name] = {} + + self._methods[vendor_name][method_name] = implementation + logger.debug(f"Registered method '{method_name}' for vendor '{vendor_name}'") + + def get_vendor(self, name: str) -> Optional[VendorMetadata]: + """Get vendor metadata by name. + + Args: + name: Vendor name + + Returns: + VendorMetadata if found, None otherwise + """ + with self._vendor_lock: + return self._vendors.get(name) + + def get_all_vendors(self) -> List[VendorMetadata]: + """Get all registered vendors sorted by priority. + + Returns: + List of VendorMetadata sorted by priority (lower = higher) + """ + with self._vendor_lock: + return sorted( + self._vendors.values(), + key=lambda v: v.priority + ) + + def get_vendors_for_capability( + self, + capability: VendorCapability, + only_enabled: bool = True + ) -> List[VendorMetadata]: + """Get all vendors that support a specific capability. + + Args: + capability: The capability to find vendors for + only_enabled: If True, only return enabled vendors + + Returns: + List of VendorMetadata sorted by priority + """ + with self._vendor_lock: + vendor_names = self._capability_index.get(capability, set()) + vendors = [ + self._vendors[name] + for name in vendor_names + if name in self._vendors + ] + + if only_enabled: + vendors = [v for v in vendors if v.enabled] + + return sorted(vendors, key=lambda v: v.priority) + + def get_method( + self, + vendor_name: str, + method_name: str + ) -> Optional[Callable]: + """Get a method implementation for a vendor. 
+ + Args: + vendor_name: Name of the vendor + method_name: Name of the method + + Returns: + Callable if found, None otherwise + """ + with self._vendor_lock: + vendor_methods = self._methods.get(vendor_name, {}) + return vendor_methods.get(method_name) + + def get_methods_for_vendor(self, vendor_name: str) -> Dict[str, Callable]: + """Get all methods registered for a vendor. + + Args: + vendor_name: Name of the vendor + + Returns: + Dictionary mapping method names to implementations + """ + with self._vendor_lock: + return dict(self._methods.get(vendor_name, {})) + + def set_vendor_enabled(self, name: str, enabled: bool) -> bool: + """Enable or disable a vendor. + + Args: + name: Vendor name + enabled: Whether vendor should be enabled + + Returns: + True if vendor was found and updated, False otherwise + """ + with self._vendor_lock: + if name not in self._vendors: + return False + self._vendors[name].enabled = enabled + return True + + def set_vendor_priority(self, name: str, priority: int) -> bool: + """Update a vendor's priority. + + Args: + name: Vendor name + priority: New priority (lower = higher priority) + + Returns: + True if vendor was found and updated, False otherwise + """ + with self._vendor_lock: + if name not in self._vendors: + return False + self._vendors[name].priority = priority + return True + + def clear(self) -> None: + """Clear all registrations. Primarily for testing.""" + with self._vendor_lock: + self._vendors.clear() + self._methods.clear() + self._capability_index.clear() + logger.debug("Cleared all vendor registrations") + + @classmethod + def reset_instance(cls) -> None: + """Reset the singleton instance. 
For testing only.""" + with cls._lock: + cls._instance = None + cls._initialized = False + + +# Module-level convenience functions +def get_registry() -> VendorRegistry: + """Get the global vendor registry instance.""" + return VendorRegistry() + + +def register_vendor(metadata: VendorMetadata) -> None: + """Register a vendor in the global registry.""" + get_registry().register_vendor(metadata) + + +def register_method(vendor_name: str, method_name: str, implementation: Callable) -> None: + """Register a method in the global registry.""" + get_registry().register_method(vendor_name, method_name, implementation)