feat(dataflows): add vendor registry pattern for extensible data vendor routing - Fixes #11
Implements [DATA-10] Interface routing - add new data vendors with: - VendorRegistry: Thread-safe singleton for centralized vendor registration - VendorCapability enum: STOCK_DATA, FUNDAMENTALS, NEWS, MACROECONOMIC, etc. - BaseVendor ABC: 3-stage lifecycle (transform_query, extract_data, transform_data) - SimpleVendor: Wrapper for migrating existing vendor functions - Decorators: @register_vendor, @vendor_method, @rate_limited, @with_retry, @cache_result - RateLimiter: Thread-safe sliding window rate limiting - 84 tests covering registry, base vendor, and decorators Files: - tradingagents/dataflows/vendor_registry.py (253 lines) - tradingagents/dataflows/base_vendor.py (222 lines) - tradingagents/dataflows/vendor_decorators.py (188 lines) - tests/unit/dataflows/test_vendor_registry.py (30 tests) - tests/unit/dataflows/test_base_vendor.py (27 tests) - tests/unit/dataflows/test_vendor_decorators.py (27 tests) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
bbd85c91b6
commit
2c802647e4
|
|
@ -0,0 +1,424 @@
|
||||||
|
"""Tests for BaseVendor abstract class.
|
||||||
|
|
||||||
|
Issue #11: [DATA-10] Interface routing - add new data vendors
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from typing import Dict, Any
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from tradingagents.dataflows.base_vendor import (
|
||||||
|
BaseVendor,
|
||||||
|
SimpleVendor,
|
||||||
|
VendorResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class ConcreteVendor(BaseVendor):
|
||||||
|
"""Concrete implementation for testing."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._vendor_name = "test_vendor"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._vendor_name
|
||||||
|
|
||||||
|
def _transform_query(self, method: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
return {"method": method, **kwargs}
|
||||||
|
|
||||||
|
def _extract_data(self, method: str, query: Dict[str, Any]) -> Any:
|
||||||
|
return {"raw": "data", "query": query}
|
||||||
|
|
||||||
|
def _transform_data(self, method: str, raw_data: Any, query: Dict[str, Any]) -> Any:
|
||||||
|
return {"transformed": raw_data}
|
||||||
|
|
||||||
|
|
||||||
|
class TestVendorResponse:
|
||||||
|
"""Tests for VendorResponse dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test default values are set correctly."""
|
||||||
|
response = VendorResponse()
|
||||||
|
assert response.data is None
|
||||||
|
assert response.success is True
|
||||||
|
assert response.vendor_name == ""
|
||||||
|
assert response.method_name == ""
|
||||||
|
assert response.execution_time_ms == 0.0
|
||||||
|
assert response.error_message == ""
|
||||||
|
assert response.metadata == {}
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""Test custom values are preserved."""
|
||||||
|
response = VendorResponse(
|
||||||
|
data={"test": "data"},
|
||||||
|
success=True,
|
||||||
|
vendor_name="yfinance",
|
||||||
|
method_name="get_stock",
|
||||||
|
execution_time_ms=150.5,
|
||||||
|
metadata={"source": "api"}
|
||||||
|
)
|
||||||
|
assert response.data == {"test": "data"}
|
||||||
|
assert response.vendor_name == "yfinance"
|
||||||
|
assert response.method_name == "get_stock"
|
||||||
|
assert response.execution_time_ms == 150.5
|
||||||
|
|
||||||
|
def test_failure_response(self):
|
||||||
|
"""Test failed response."""
|
||||||
|
response = VendorResponse(
|
||||||
|
success=False,
|
||||||
|
error_message="API Error"
|
||||||
|
)
|
||||||
|
assert response.success is False
|
||||||
|
assert response.error_message == "API Error"
|
||||||
|
|
||||||
|
def test_is_empty_with_none(self):
|
||||||
|
"""Test is_empty returns True for None data."""
|
||||||
|
response = VendorResponse(data=None)
|
||||||
|
assert response.is_empty is True
|
||||||
|
|
||||||
|
def test_is_empty_with_empty_list(self):
|
||||||
|
"""Test is_empty returns True for empty list."""
|
||||||
|
response = VendorResponse(data=[])
|
||||||
|
assert response.is_empty is True
|
||||||
|
|
||||||
|
def test_is_empty_with_empty_dict(self):
|
||||||
|
"""Test is_empty returns True for empty dict."""
|
||||||
|
response = VendorResponse(data={})
|
||||||
|
assert response.is_empty is True
|
||||||
|
|
||||||
|
def test_is_empty_with_empty_string(self):
|
||||||
|
"""Test is_empty returns True for empty string."""
|
||||||
|
response = VendorResponse(data="")
|
||||||
|
assert response.is_empty is True
|
||||||
|
|
||||||
|
def test_is_empty_with_data(self):
|
||||||
|
"""Test is_empty returns False when data exists."""
|
||||||
|
response = VendorResponse(data={"test": "data"})
|
||||||
|
assert response.is_empty is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseVendorAbstract:
|
||||||
|
"""Tests for BaseVendor abstract method enforcement."""
|
||||||
|
|
||||||
|
def test_cannot_instantiate_base_vendor(self):
|
||||||
|
"""Test that BaseVendor cannot be instantiated directly."""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
BaseVendor()
|
||||||
|
|
||||||
|
def test_can_instantiate_concrete_vendor(self):
|
||||||
|
"""Test that concrete implementation can be instantiated."""
|
||||||
|
vendor = ConcreteVendor()
|
||||||
|
assert vendor is not None
|
||||||
|
assert vendor.name == "test_vendor"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseVendorConfiguration:
|
||||||
|
"""Tests for BaseVendor configuration."""
|
||||||
|
|
||||||
|
def test_default_configuration(self):
|
||||||
|
"""Test default configuration values."""
|
||||||
|
vendor = ConcreteVendor()
|
||||||
|
assert vendor._max_retries == 3
|
||||||
|
assert vendor._retry_delay == 1.0
|
||||||
|
assert vendor._retry_backoff == 2.0
|
||||||
|
assert vendor._timeout is None
|
||||||
|
|
||||||
|
def test_custom_configuration(self):
|
||||||
|
"""Test custom configuration values."""
|
||||||
|
vendor = ConcreteVendor(
|
||||||
|
max_retries=5,
|
||||||
|
retry_delay=0.5,
|
||||||
|
retry_backoff=3.0,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
assert vendor._max_retries == 5
|
||||||
|
assert vendor._retry_delay == 0.5
|
||||||
|
assert vendor._retry_backoff == 3.0
|
||||||
|
assert vendor._timeout == 30.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseVendorExecution:
|
||||||
|
"""Tests for BaseVendor execute method."""
|
||||||
|
|
||||||
|
def test_successful_execution(self):
|
||||||
|
"""Test successful execution returns VendorResponse."""
|
||||||
|
vendor = ConcreteVendor()
|
||||||
|
response = vendor.execute("get_data", ticker="AAPL")
|
||||||
|
|
||||||
|
assert isinstance(response, VendorResponse)
|
||||||
|
assert response.success is True
|
||||||
|
assert response.vendor_name == "test_vendor"
|
||||||
|
assert response.method_name == "get_data"
|
||||||
|
assert response.execution_time_ms > 0
|
||||||
|
|
||||||
|
def test_execution_increments_call_count(self):
|
||||||
|
"""Test that execute increments call count."""
|
||||||
|
vendor = ConcreteVendor()
|
||||||
|
assert vendor.call_count == 0
|
||||||
|
|
||||||
|
vendor.execute("method1")
|
||||||
|
assert vendor.call_count == 1
|
||||||
|
|
||||||
|
vendor.execute("method2")
|
||||||
|
assert vendor.call_count == 2
|
||||||
|
|
||||||
|
def test_call_count_thread_safe(self):
|
||||||
|
"""Test that call_count is thread-safe."""
|
||||||
|
vendor = ConcreteVendor()
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def execute_many():
|
||||||
|
try:
|
||||||
|
for _ in range(100):
|
||||||
|
vendor.execute("test")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=execute_many) for _ in range(5)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert len(errors) == 0
|
||||||
|
assert vendor.call_count == 500
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseVendorRetry:
|
||||||
|
"""Tests for BaseVendor retry logic."""
|
||||||
|
|
||||||
|
def test_retry_on_connection_error(self):
|
||||||
|
"""Test retry on connection error."""
|
||||||
|
attempt_count = [0]
|
||||||
|
|
||||||
|
class RetryVendor(ConcreteVendor):
|
||||||
|
def _extract_data(self, method, query):
|
||||||
|
attempt_count[0] += 1
|
||||||
|
if attempt_count[0] < 3:
|
||||||
|
raise ConnectionError("Network error")
|
||||||
|
return {"data": "success"}
|
||||||
|
|
||||||
|
vendor = RetryVendor(retry_delay=0.01)
|
||||||
|
response = vendor.execute("test")
|
||||||
|
|
||||||
|
assert attempt_count[0] == 3
|
||||||
|
assert response.success is True
|
||||||
|
|
||||||
|
def test_no_retry_on_value_error(self):
|
||||||
|
"""Test no retry on non-retryable error."""
|
||||||
|
class NoRetryVendor(ConcreteVendor):
|
||||||
|
def _extract_data(self, method, query):
|
||||||
|
raise ValueError("Invalid input")
|
||||||
|
|
||||||
|
vendor = NoRetryVendor(retry_delay=0.01)
|
||||||
|
response = vendor.execute("test")
|
||||||
|
|
||||||
|
assert response.success is False
|
||||||
|
assert "Invalid input" in response.error_message
|
||||||
|
|
||||||
|
def test_max_retries_exhausted(self):
|
||||||
|
"""Test that max retries is respected."""
|
||||||
|
attempt_count = [0]
|
||||||
|
|
||||||
|
class AlwaysFailVendor(ConcreteVendor):
|
||||||
|
def _extract_data(self, method, query):
|
||||||
|
attempt_count[0] += 1
|
||||||
|
raise ConnectionError("Network error")
|
||||||
|
|
||||||
|
vendor = AlwaysFailVendor(max_retries=3, retry_delay=0.01)
|
||||||
|
response = vendor.execute("test")
|
||||||
|
|
||||||
|
assert attempt_count[0] == 4 # 1 initial + 3 retries
|
||||||
|
assert response.success is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseVendorStatistics:
|
||||||
|
"""Tests for BaseVendor statistics."""
|
||||||
|
|
||||||
|
def test_error_rate_calculation(self):
|
||||||
|
"""Test error rate calculation."""
|
||||||
|
class MixedVendor(ConcreteVendor):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._fail_count = 0
|
||||||
|
|
||||||
|
def _extract_data(self, method, query):
|
||||||
|
self._fail_count += 1
|
||||||
|
if self._fail_count <= 2:
|
||||||
|
raise ValueError("Error")
|
||||||
|
return {"data": "success"}
|
||||||
|
|
||||||
|
vendor = MixedVendor(max_retries=0, retry_delay=0.01)
|
||||||
|
|
||||||
|
# 2 failures
|
||||||
|
vendor.execute("test")
|
||||||
|
vendor.execute("test")
|
||||||
|
# 1 success
|
||||||
|
vendor.execute("test")
|
||||||
|
|
||||||
|
# 2 errors out of 3 = 66.67%
|
||||||
|
assert 66 < vendor.error_rate < 67
|
||||||
|
|
||||||
|
def test_reset_stats(self):
|
||||||
|
"""Test reset_stats clears counters."""
|
||||||
|
vendor = ConcreteVendor()
|
||||||
|
vendor.execute("test")
|
||||||
|
vendor.execute("test")
|
||||||
|
|
||||||
|
assert vendor.call_count == 2
|
||||||
|
|
||||||
|
vendor.reset_stats()
|
||||||
|
|
||||||
|
assert vendor.call_count == 0
|
||||||
|
assert vendor.error_rate == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimpleVendor:
|
||||||
|
"""Tests for SimpleVendor wrapper class."""
|
||||||
|
|
||||||
|
def test_simple_vendor_creation(self):
|
||||||
|
"""Test creating a SimpleVendor."""
|
||||||
|
mock_func = Mock(return_value={"data": "test"})
|
||||||
|
vendor = SimpleVendor(
|
||||||
|
vendor_name="test",
|
||||||
|
methods={"get_data": mock_func}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vendor.name == "test"
|
||||||
|
|
||||||
|
def test_simple_vendor_execution(self):
|
||||||
|
"""Test executing a SimpleVendor method."""
|
||||||
|
mock_func = Mock(return_value={"data": "test"})
|
||||||
|
vendor = SimpleVendor(
|
||||||
|
vendor_name="test",
|
||||||
|
methods={"get_data": mock_func}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = vendor.execute("get_data", ticker="AAPL")
|
||||||
|
|
||||||
|
assert response.success is True
|
||||||
|
assert response.data == {"data": "test"}
|
||||||
|
mock_func.assert_called_once_with(ticker="AAPL")
|
||||||
|
|
||||||
|
def test_simple_vendor_missing_method(self):
|
||||||
|
"""Test SimpleVendor with missing method."""
|
||||||
|
vendor = SimpleVendor(
|
||||||
|
vendor_name="test",
|
||||||
|
methods={}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = vendor.execute("nonexistent")
|
||||||
|
|
||||||
|
assert response.success is False
|
||||||
|
assert "not found" in response.error_message
|
||||||
|
|
||||||
|
def test_simple_vendor_add_method(self):
|
||||||
|
"""Test adding a method to SimpleVendor."""
|
||||||
|
vendor = SimpleVendor(
|
||||||
|
vendor_name="test",
|
||||||
|
methods={}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_func = Mock(return_value="result")
|
||||||
|
vendor.add_method("new_method", mock_func)
|
||||||
|
|
||||||
|
response = vendor.execute("new_method")
|
||||||
|
|
||||||
|
assert response.success is True
|
||||||
|
assert response.data == "result"
|
||||||
|
|
||||||
|
def test_simple_vendor_get_methods(self):
|
||||||
|
"""Test getting list of methods."""
|
||||||
|
vendor = SimpleVendor(
|
||||||
|
vendor_name="test",
|
||||||
|
methods={
|
||||||
|
"method1": Mock(),
|
||||||
|
"method2": Mock(),
|
||||||
|
"method3": Mock()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
methods = vendor.get_methods()
|
||||||
|
|
||||||
|
assert len(methods) == 3
|
||||||
|
assert "method1" in methods
|
||||||
|
assert "method2" in methods
|
||||||
|
assert "method3" in methods
|
||||||
|
|
||||||
|
|
||||||
|
class TestVendorLifecycle:
|
||||||
|
"""Tests for 3-stage vendor lifecycle."""
|
||||||
|
|
||||||
|
def test_lifecycle_order(self):
|
||||||
|
"""Test that lifecycle stages execute in order."""
|
||||||
|
call_order = []
|
||||||
|
|
||||||
|
class OrderVendor(BaseVendor):
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return "order_test"
|
||||||
|
|
||||||
|
def _transform_query(self, method, **kwargs):
|
||||||
|
call_order.append("transform_query")
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def _extract_data(self, method, query):
|
||||||
|
call_order.append("extract_data")
|
||||||
|
return {"raw": True}
|
||||||
|
|
||||||
|
def _transform_data(self, method, raw_data, query):
|
||||||
|
call_order.append("transform_data")
|
||||||
|
return raw_data
|
||||||
|
|
||||||
|
vendor = OrderVendor()
|
||||||
|
vendor.execute("test")
|
||||||
|
|
||||||
|
assert call_order == ["transform_query", "extract_data", "transform_data"]
|
||||||
|
|
||||||
|
def test_query_passed_to_extract(self):
|
||||||
|
"""Test that transformed query is passed to extract."""
|
||||||
|
class QueryVendor(BaseVendor):
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return "query_test"
|
||||||
|
|
||||||
|
def _transform_query(self, method, **kwargs):
|
||||||
|
return {"ticker": kwargs.get("ticker", "").upper()}
|
||||||
|
|
||||||
|
def _extract_data(self, method, query):
|
||||||
|
return {"symbol": query["ticker"]}
|
||||||
|
|
||||||
|
def _transform_data(self, method, raw_data, query):
|
||||||
|
return raw_data
|
||||||
|
|
||||||
|
vendor = QueryVendor()
|
||||||
|
response = vendor.execute("test", ticker="aapl")
|
||||||
|
|
||||||
|
assert response.data["symbol"] == "AAPL"
|
||||||
|
|
||||||
|
def test_raw_data_passed_to_transform(self):
|
||||||
|
"""Test that raw data is passed to transform."""
|
||||||
|
class TransformVendor(BaseVendor):
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return "transform_test"
|
||||||
|
|
||||||
|
def _transform_query(self, method, **kwargs):
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def _extract_data(self, method, query):
|
||||||
|
return {"raw_value": 42}
|
||||||
|
|
||||||
|
def _transform_data(self, method, raw_data, query):
|
||||||
|
return {"processed": raw_data["raw_value"] * 2}
|
||||||
|
|
||||||
|
vendor = TransformVendor()
|
||||||
|
response = vendor.execute("test")
|
||||||
|
|
||||||
|
assert response.data["processed"] == 84
|
||||||
|
|
@ -0,0 +1,416 @@
|
||||||
|
"""Tests for vendor decorators.
|
||||||
|
|
||||||
|
Issue #11: [DATA-10] Interface routing - add new data vendors
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from tradingagents.dataflows.vendor_registry import (
|
||||||
|
VendorCapability,
|
||||||
|
VendorMetadata,
|
||||||
|
VendorRegistry,
|
||||||
|
)
|
||||||
|
from tradingagents.dataflows.vendor_decorators import (
|
||||||
|
register_vendor,
|
||||||
|
vendor_method,
|
||||||
|
rate_limited,
|
||||||
|
with_retry,
|
||||||
|
cache_result,
|
||||||
|
RateLimiter,
|
||||||
|
get_rate_limiter,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_registry():
|
||||||
|
"""Reset the singleton registry before each test."""
|
||||||
|
VendorRegistry.reset_instance()
|
||||||
|
yield
|
||||||
|
VendorRegistry.reset_instance()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterVendorDecorator:
|
||||||
|
"""Tests for @register_vendor decorator."""
|
||||||
|
|
||||||
|
def test_register_vendor_basic(self):
|
||||||
|
"""Test basic vendor registration with decorator."""
|
||||||
|
@register_vendor(
|
||||||
|
name="test_vendor",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA},
|
||||||
|
priority=10
|
||||||
|
)
|
||||||
|
class TestVendor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert hasattr(TestVendor, '_vendor_metadata')
|
||||||
|
assert TestVendor._vendor_name == "test_vendor"
|
||||||
|
|
||||||
|
def test_register_vendor_adds_to_registry(self):
|
||||||
|
"""Test that decorator adds vendor to registry."""
|
||||||
|
@register_vendor(
|
||||||
|
name="registered_vendor",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA}
|
||||||
|
)
|
||||||
|
class RegisteredVendor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
registry = VendorRegistry()
|
||||||
|
assert registry.get_vendor("registered_vendor") is not None
|
||||||
|
|
||||||
|
def test_register_vendor_with_multiple_capabilities(self):
|
||||||
|
"""Test registering vendor with multiple capabilities."""
|
||||||
|
@register_vendor(
|
||||||
|
name="multi_vendor",
|
||||||
|
capabilities={
|
||||||
|
VendorCapability.STOCK_DATA,
|
||||||
|
VendorCapability.FUNDAMENTALS,
|
||||||
|
VendorCapability.NEWS
|
||||||
|
}
|
||||||
|
)
|
||||||
|
class MultiVendor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
registry = VendorRegistry()
|
||||||
|
vendor = registry.get_vendor("multi_vendor")
|
||||||
|
assert len(vendor.capabilities) == 3
|
||||||
|
|
||||||
|
def test_register_vendor_preserves_class_methods(self):
|
||||||
|
"""Test that decorator preserves class methods."""
|
||||||
|
@register_vendor(
|
||||||
|
name="method_vendor",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA}
|
||||||
|
)
|
||||||
|
class MethodVendor:
|
||||||
|
def fetch_data(self):
|
||||||
|
return "data"
|
||||||
|
|
||||||
|
vendor = MethodVendor()
|
||||||
|
assert vendor.fetch_data() == "data"
|
||||||
|
|
||||||
|
|
||||||
|
class TestVendorMethodDecorator:
|
||||||
|
"""Tests for @vendor_method decorator."""
|
||||||
|
|
||||||
|
def test_vendor_method_basic(self):
|
||||||
|
"""Test basic method mapping."""
|
||||||
|
class TestVendor:
|
||||||
|
@vendor_method("get_stock_data")
|
||||||
|
def fetch_stock(self, ticker):
|
||||||
|
return f"Stock: {ticker}"
|
||||||
|
|
||||||
|
vendor = TestVendor()
|
||||||
|
result = vendor.fetch_stock("AAPL")
|
||||||
|
|
||||||
|
assert result == "Stock: AAPL"
|
||||||
|
assert vendor.fetch_stock._vendor_method_name == "get_stock_data"
|
||||||
|
|
||||||
|
def test_vendor_method_with_vendor_name(self):
|
||||||
|
"""Test vendor_method with explicit vendor name."""
|
||||||
|
# Register the vendor first
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="yfinance"))
|
||||||
|
|
||||||
|
@vendor_method("get_stock", vendor_name="yfinance")
|
||||||
|
def get_yfinance_stock(ticker):
|
||||||
|
return f"YF: {ticker}"
|
||||||
|
|
||||||
|
# Method should be registered
|
||||||
|
method = registry.get_method("yfinance", "get_stock")
|
||||||
|
assert method is not None
|
||||||
|
|
||||||
|
def test_vendor_method_preserves_function(self):
|
||||||
|
"""Test that decorator preserves function behavior."""
|
||||||
|
class TestVendor:
|
||||||
|
@vendor_method("get_data")
|
||||||
|
def fetch_data(self, ticker, start_date=None):
|
||||||
|
return {"ticker": ticker, "start": start_date}
|
||||||
|
|
||||||
|
vendor = TestVendor()
|
||||||
|
result = vendor.fetch_data("AAPL", start_date="2024-01-01")
|
||||||
|
|
||||||
|
assert result["ticker"] == "AAPL"
|
||||||
|
assert result["start"] == "2024-01-01"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimiter:
|
||||||
|
"""Tests for RateLimiter class."""
|
||||||
|
|
||||||
|
def test_rate_limiter_allows_calls_under_limit(self):
|
||||||
|
"""Test that calls under limit are allowed."""
|
||||||
|
limiter = RateLimiter(max_calls=5, period_seconds=60)
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
assert limiter.acquire() is True
|
||||||
|
|
||||||
|
def test_rate_limiter_blocks_over_limit(self):
|
||||||
|
"""Test that calls over limit are blocked."""
|
||||||
|
limiter = RateLimiter(max_calls=3, period_seconds=60)
|
||||||
|
|
||||||
|
# Use all slots
|
||||||
|
for _ in range(3):
|
||||||
|
limiter.acquire()
|
||||||
|
|
||||||
|
# 4th call should be blocked
|
||||||
|
assert limiter.acquire() is False
|
||||||
|
|
||||||
|
def test_rate_limiter_wait_time(self):
|
||||||
|
"""Test wait_time calculation."""
|
||||||
|
limiter = RateLimiter(max_calls=2, period_seconds=60)
|
||||||
|
|
||||||
|
# Fill up the slots
|
||||||
|
limiter.acquire()
|
||||||
|
limiter.acquire()
|
||||||
|
|
||||||
|
# Should have positive wait time
|
||||||
|
wait = limiter.wait_time()
|
||||||
|
assert wait > 0
|
||||||
|
assert wait <= 60
|
||||||
|
|
||||||
|
def test_rate_limiter_reset(self):
|
||||||
|
"""Test reset clears the limiter."""
|
||||||
|
limiter = RateLimiter(max_calls=2, period_seconds=60)
|
||||||
|
|
||||||
|
limiter.acquire()
|
||||||
|
limiter.acquire()
|
||||||
|
assert limiter.acquire() is False
|
||||||
|
|
||||||
|
limiter.reset()
|
||||||
|
assert limiter.acquire() is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitedDecorator:
|
||||||
|
"""Tests for @rate_limited decorator."""
|
||||||
|
|
||||||
|
def test_rate_limited_allows_calls_under_limit(self):
|
||||||
|
"""Test that calls under limit succeed."""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@rate_limited(max_calls=5, period_seconds=60)
|
||||||
|
def limited_func():
|
||||||
|
call_count[0] += 1
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
result = limited_func()
|
||||||
|
assert result == "success"
|
||||||
|
|
||||||
|
assert call_count[0] == 5
|
||||||
|
|
||||||
|
def test_rate_limited_blocks_over_limit(self):
|
||||||
|
"""Test that calls over limit raise error."""
|
||||||
|
@rate_limited(max_calls=2, period_seconds=60)
|
||||||
|
def limited_func():
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
limited_func()
|
||||||
|
limited_func()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Rate limit exceeded"):
|
||||||
|
limited_func()
|
||||||
|
|
||||||
|
def test_rate_limited_custom_exception(self):
|
||||||
|
"""Test rate limiting with custom exception class."""
|
||||||
|
class CustomRateLimitError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@rate_limited(
|
||||||
|
max_calls=1,
|
||||||
|
period_seconds=60,
|
||||||
|
exception_class=CustomRateLimitError
|
||||||
|
)
|
||||||
|
def limited_func():
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
limited_func()
|
||||||
|
|
||||||
|
with pytest.raises(CustomRateLimitError):
|
||||||
|
limited_func()
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithRetryDecorator:
|
||||||
|
"""Tests for @with_retry decorator."""
|
||||||
|
|
||||||
|
def test_retry_on_retryable_error(self):
|
||||||
|
"""Test retry on retryable exceptions."""
|
||||||
|
attempt_count = [0]
|
||||||
|
|
||||||
|
@with_retry(max_retries=3, retry_delay=0.01)
|
||||||
|
def failing_func():
|
||||||
|
attempt_count[0] += 1
|
||||||
|
if attempt_count[0] < 3:
|
||||||
|
raise ConnectionError("Network error")
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
result = failing_func()
|
||||||
|
assert result == "success"
|
||||||
|
assert attempt_count[0] == 3
|
||||||
|
|
||||||
|
def test_no_retry_on_non_retryable_error(self):
|
||||||
|
"""Test no retry on non-retryable exceptions."""
|
||||||
|
attempt_count = [0]
|
||||||
|
|
||||||
|
@with_retry(
|
||||||
|
max_retries=3,
|
||||||
|
retry_delay=0.01,
|
||||||
|
retryable_exceptions=(ConnectionError,)
|
||||||
|
)
|
||||||
|
def failing_func():
|
||||||
|
attempt_count[0] += 1
|
||||||
|
raise ValueError("Not retryable")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
failing_func()
|
||||||
|
|
||||||
|
assert attempt_count[0] == 1
|
||||||
|
|
||||||
|
def test_retry_exhausted(self):
|
||||||
|
"""Test exception raised after max retries."""
|
||||||
|
@with_retry(max_retries=2, retry_delay=0.01)
|
||||||
|
def always_fails():
|
||||||
|
raise ConnectionError("Always fails")
|
||||||
|
|
||||||
|
with pytest.raises(ConnectionError):
|
||||||
|
always_fails()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheResultDecorator:
|
||||||
|
"""Tests for @cache_result decorator."""
|
||||||
|
|
||||||
|
def test_cache_returns_cached_value(self):
|
||||||
|
"""Test that cached value is returned."""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@cache_result(ttl_seconds=300)
|
||||||
|
def expensive_func(key):
|
||||||
|
call_count[0] += 1
|
||||||
|
return f"value_{key}"
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = expensive_func("test")
|
||||||
|
# Second call (should use cache)
|
||||||
|
result2 = expensive_func("test")
|
||||||
|
|
||||||
|
assert result1 == result2
|
||||||
|
assert call_count[0] == 1
|
||||||
|
|
||||||
|
def test_cache_different_keys(self):
|
||||||
|
"""Test that different keys have different cache entries."""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@cache_result(ttl_seconds=300)
|
||||||
|
def expensive_func(key):
|
||||||
|
call_count[0] += 1
|
||||||
|
return f"value_{key}"
|
||||||
|
|
||||||
|
result1 = expensive_func("key1")
|
||||||
|
result2 = expensive_func("key2")
|
||||||
|
|
||||||
|
assert result1 != result2
|
||||||
|
assert call_count[0] == 2
|
||||||
|
|
||||||
|
def test_cache_clear(self):
|
||||||
|
"""Test clearing the cache."""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@cache_result(ttl_seconds=300)
|
||||||
|
def expensive_func(key):
|
||||||
|
call_count[0] += 1
|
||||||
|
return f"value_{key}"
|
||||||
|
|
||||||
|
expensive_func("test")
|
||||||
|
expensive_func.clear_cache()
|
||||||
|
expensive_func("test")
|
||||||
|
|
||||||
|
assert call_count[0] == 2
|
||||||
|
|
||||||
|
def test_cache_info(self):
|
||||||
|
"""Test cache info method."""
|
||||||
|
@cache_result(ttl_seconds=60, max_size=100)
|
||||||
|
def cached_func(key):
|
||||||
|
return key
|
||||||
|
|
||||||
|
cached_func("test1")
|
||||||
|
cached_func("test2")
|
||||||
|
|
||||||
|
info = cached_func.cache_info()
|
||||||
|
assert info["size"] == 2
|
||||||
|
assert info["max_size"] == 100
|
||||||
|
assert info["ttl_seconds"] == 60
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecoratorStacking:
|
||||||
|
"""Tests for stacking multiple decorators."""
|
||||||
|
|
||||||
|
def test_register_and_method_decorators(self):
|
||||||
|
"""Test stacking @register_vendor and @vendor_method."""
|
||||||
|
@register_vendor(
|
||||||
|
name="stacked_vendor",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA}
|
||||||
|
)
|
||||||
|
class StackedVendor:
|
||||||
|
@vendor_method("get_stock_data")
|
||||||
|
def fetch_stock(self, ticker):
|
||||||
|
return f"Stock: {ticker}"
|
||||||
|
|
||||||
|
vendor = StackedVendor()
|
||||||
|
|
||||||
|
# Vendor should be registered
|
||||||
|
registry = VendorRegistry()
|
||||||
|
assert registry.get_vendor("stacked_vendor") is not None
|
||||||
|
|
||||||
|
# Method should work
|
||||||
|
result = vendor.fetch_stock("AAPL")
|
||||||
|
assert result == "Stock: AAPL"
|
||||||
|
|
||||||
|
def test_method_with_rate_limit(self):
|
||||||
|
"""Test stacking @vendor_method with @rate_limited."""
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
class TestVendor:
|
||||||
|
@vendor_method("get_data")
|
||||||
|
@rate_limited(max_calls=3, period_seconds=60)
|
||||||
|
def fetch_data(self):
|
||||||
|
call_count[0] += 1
|
||||||
|
return "data"
|
||||||
|
|
||||||
|
vendor = TestVendor()
|
||||||
|
|
||||||
|
# Should work 3 times
|
||||||
|
for _ in range(3):
|
||||||
|
vendor.fetch_data()
|
||||||
|
|
||||||
|
# 4th should fail
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
vendor.fetch_data()
|
||||||
|
|
||||||
|
assert call_count[0] == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetRateLimiter:
|
||||||
|
"""Tests for get_rate_limiter function."""
|
||||||
|
|
||||||
|
def test_returns_same_limiter_for_same_params(self):
|
||||||
|
"""Test that same params return same limiter."""
|
||||||
|
limiter1 = get_rate_limiter("vendor", 5, 60)
|
||||||
|
limiter2 = get_rate_limiter("vendor", 5, 60)
|
||||||
|
|
||||||
|
assert limiter1 is limiter2
|
||||||
|
|
||||||
|
def test_returns_different_limiter_for_different_vendor(self):
|
||||||
|
"""Test that different vendors get different limiters."""
|
||||||
|
limiter1 = get_rate_limiter("vendor1", 5, 60)
|
||||||
|
limiter2 = get_rate_limiter("vendor2", 5, 60)
|
||||||
|
|
||||||
|
assert limiter1 is not limiter2
|
||||||
|
|
||||||
|
def test_returns_different_limiter_for_different_params(self):
|
||||||
|
"""Test that different params get different limiters."""
|
||||||
|
limiter1 = get_rate_limiter("vendor", 5, 60)
|
||||||
|
limiter2 = get_rate_limiter("vendor", 10, 60)
|
||||||
|
|
||||||
|
assert limiter1 is not limiter2
|
||||||
|
|
@ -0,0 +1,354 @@
|
||||||
|
"""Tests for VendorRegistry.
|
||||||
|
|
||||||
|
Issue #11: [DATA-10] Interface routing - add new data vendors
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import threading
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from tradingagents.dataflows.vendor_registry import (
|
||||||
|
VendorCapability,
|
||||||
|
VendorMetadata,
|
||||||
|
VendorRegistry,
|
||||||
|
get_registry,
|
||||||
|
register_vendor,
|
||||||
|
register_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_registry():
|
||||||
|
"""Reset the singleton registry before each test."""
|
||||||
|
VendorRegistry.reset_instance()
|
||||||
|
yield
|
||||||
|
VendorRegistry.reset_instance()
|
||||||
|
|
||||||
|
|
||||||
|
class TestVendorMetadata:
|
||||||
|
"""Tests for VendorMetadata dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test default values are set correctly."""
|
||||||
|
metadata = VendorMetadata(name="test")
|
||||||
|
assert metadata.name == "test"
|
||||||
|
assert metadata.priority == 100
|
||||||
|
assert metadata.capabilities == set()
|
||||||
|
assert metadata.rate_limit_exception is None
|
||||||
|
assert metadata.description == ""
|
||||||
|
assert metadata.enabled is True
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""Test custom values are preserved."""
|
||||||
|
metadata = VendorMetadata(
|
||||||
|
name="yfinance",
|
||||||
|
priority=10,
|
||||||
|
capabilities={VendorCapability.STOCK_DATA},
|
||||||
|
description="Yahoo Finance vendor"
|
||||||
|
)
|
||||||
|
assert metadata.name == "yfinance"
|
||||||
|
assert metadata.priority == 10
|
||||||
|
assert VendorCapability.STOCK_DATA in metadata.capabilities
|
||||||
|
assert metadata.description == "Yahoo Finance vendor"
|
||||||
|
|
||||||
|
def test_hash(self):
|
||||||
|
"""Test VendorMetadata is hashable by name."""
|
||||||
|
m1 = VendorMetadata(name="vendor1")
|
||||||
|
m2 = VendorMetadata(name="vendor1", priority=50)
|
||||||
|
assert hash(m1) == hash(m2)
|
||||||
|
|
||||||
|
def test_equality(self):
|
||||||
|
"""Test VendorMetadata equality is by name."""
|
||||||
|
m1 = VendorMetadata(name="vendor1", priority=10)
|
||||||
|
m2 = VendorMetadata(name="vendor1", priority=50)
|
||||||
|
m3 = VendorMetadata(name="vendor2")
|
||||||
|
assert m1 == m2
|
||||||
|
assert m1 != m3
|
||||||
|
|
||||||
|
def test_equality_with_non_metadata(self):
|
||||||
|
"""Test equality with non-VendorMetadata objects."""
|
||||||
|
metadata = VendorMetadata(name="test")
|
||||||
|
assert metadata != "test"
|
||||||
|
assert metadata != {"name": "test"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestVendorRegistry:
|
||||||
|
"""Tests for VendorRegistry singleton."""
|
||||||
|
|
||||||
|
def test_singleton_pattern(self):
|
||||||
|
"""Test that VendorRegistry is a singleton."""
|
||||||
|
r1 = VendorRegistry()
|
||||||
|
r2 = VendorRegistry()
|
||||||
|
assert r1 is r2
|
||||||
|
|
||||||
|
def test_get_registry_returns_singleton(self):
|
||||||
|
"""Test get_registry returns same instance."""
|
||||||
|
r1 = get_registry()
|
||||||
|
r2 = VendorRegistry()
|
||||||
|
assert r1 is r2
|
||||||
|
|
||||||
|
def test_reset_instance(self):
|
||||||
|
"""Test reset_instance creates new singleton."""
|
||||||
|
r1 = VendorRegistry()
|
||||||
|
VendorRegistry.reset_instance()
|
||||||
|
r2 = VendorRegistry()
|
||||||
|
assert r1 is not r2
|
||||||
|
|
||||||
|
def test_register_vendor(self):
|
||||||
|
"""Test registering a vendor."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
metadata = VendorMetadata(
|
||||||
|
name="yfinance",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA}
|
||||||
|
)
|
||||||
|
registry.register_vendor(metadata)
|
||||||
|
|
||||||
|
result = registry.get_vendor("yfinance")
|
||||||
|
assert result is not None
|
||||||
|
assert result.name == "yfinance"
|
||||||
|
|
||||||
|
def test_register_vendor_empty_name_raises(self):
|
||||||
|
"""Test registering vendor with empty name raises ValueError."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
registry.register_vendor(VendorMetadata(name=""))
|
||||||
|
|
||||||
|
def test_unregister_vendor(self):
|
||||||
|
"""Test unregistering a vendor."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
metadata = VendorMetadata(name="test_vendor")
|
||||||
|
registry.register_vendor(metadata)
|
||||||
|
|
||||||
|
assert registry.unregister_vendor("test_vendor") is True
|
||||||
|
assert registry.get_vendor("test_vendor") is None
|
||||||
|
|
||||||
|
def test_unregister_nonexistent_vendor(self):
|
||||||
|
"""Test unregistering non-existent vendor returns False."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
assert registry.unregister_vendor("nonexistent") is False
|
||||||
|
|
||||||
|
def test_get_all_vendors_sorted_by_priority(self):
|
||||||
|
"""Test get_all_vendors returns vendors sorted by priority."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="low", priority=100))
|
||||||
|
registry.register_vendor(VendorMetadata(name="high", priority=10))
|
||||||
|
registry.register_vendor(VendorMetadata(name="medium", priority=50))
|
||||||
|
|
||||||
|
vendors = registry.get_all_vendors()
|
||||||
|
assert [v.name for v in vendors] == ["high", "medium", "low"]
|
||||||
|
|
||||||
|
def test_get_vendors_for_capability(self):
|
||||||
|
"""Test getting vendors by capability."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(
|
||||||
|
name="yfinance",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA, VendorCapability.FUNDAMENTALS}
|
||||||
|
))
|
||||||
|
registry.register_vendor(VendorMetadata(
|
||||||
|
name="alpha_vantage",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA}
|
||||||
|
))
|
||||||
|
|
||||||
|
stock_vendors = registry.get_vendors_for_capability(VendorCapability.STOCK_DATA)
|
||||||
|
assert len(stock_vendors) == 2
|
||||||
|
|
||||||
|
fundamental_vendors = registry.get_vendors_for_capability(VendorCapability.FUNDAMENTALS)
|
||||||
|
assert len(fundamental_vendors) == 1
|
||||||
|
assert fundamental_vendors[0].name == "yfinance"
|
||||||
|
|
||||||
|
def test_get_vendors_for_capability_only_enabled(self):
|
||||||
|
"""Test only enabled vendors are returned by default."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(
|
||||||
|
name="enabled_vendor",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA}
|
||||||
|
))
|
||||||
|
registry.register_vendor(VendorMetadata(
|
||||||
|
name="disabled_vendor",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA},
|
||||||
|
enabled=False
|
||||||
|
))
|
||||||
|
|
||||||
|
vendors = registry.get_vendors_for_capability(VendorCapability.STOCK_DATA)
|
||||||
|
assert len(vendors) == 1
|
||||||
|
assert vendors[0].name == "enabled_vendor"
|
||||||
|
|
||||||
|
all_vendors = registry.get_vendors_for_capability(
|
||||||
|
VendorCapability.STOCK_DATA,
|
||||||
|
only_enabled=False
|
||||||
|
)
|
||||||
|
assert len(all_vendors) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestVendorRegistryMethods:
|
||||||
|
"""Tests for method registration in VendorRegistry."""
|
||||||
|
|
||||||
|
def test_register_method(self):
|
||||||
|
"""Test registering a method for a vendor."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="yfinance"))
|
||||||
|
|
||||||
|
mock_func = Mock()
|
||||||
|
registry.register_method("yfinance", "get_stock", mock_func)
|
||||||
|
|
||||||
|
result = registry.get_method("yfinance", "get_stock")
|
||||||
|
assert result is mock_func
|
||||||
|
|
||||||
|
def test_register_method_unregistered_vendor_raises(self):
|
||||||
|
"""Test registering method for unregistered vendor raises."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
with pytest.raises(ValueError, match="not registered"):
|
||||||
|
registry.register_method("nonexistent", "method", Mock())
|
||||||
|
|
||||||
|
def test_register_method_empty_name_raises(self):
|
||||||
|
"""Test registering method with empty name raises."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="test"))
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
registry.register_method("test", "", Mock())
|
||||||
|
|
||||||
|
def test_get_method_nonexistent(self):
|
||||||
|
"""Test getting non-existent method returns None."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="test"))
|
||||||
|
assert registry.get_method("test", "nonexistent") is None
|
||||||
|
|
||||||
|
def test_get_methods_for_vendor(self):
|
||||||
|
"""Test getting all methods for a vendor."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="yfinance"))
|
||||||
|
|
||||||
|
func1 = Mock()
|
||||||
|
func2 = Mock()
|
||||||
|
registry.register_method("yfinance", "get_stock", func1)
|
||||||
|
registry.register_method("yfinance", "get_fundamentals", func2)
|
||||||
|
|
||||||
|
methods = registry.get_methods_for_vendor("yfinance")
|
||||||
|
assert len(methods) == 2
|
||||||
|
assert methods["get_stock"] is func1
|
||||||
|
assert methods["get_fundamentals"] is func2
|
||||||
|
|
||||||
|
def test_get_methods_for_unregistered_vendor(self):
|
||||||
|
"""Test getting methods for unregistered vendor returns empty dict."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
assert registry.get_methods_for_vendor("nonexistent") == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestVendorRegistryVendorControl:
|
||||||
|
"""Tests for vendor enable/disable and priority control."""
|
||||||
|
|
||||||
|
def test_set_vendor_enabled(self):
|
||||||
|
"""Test enabling/disabling a vendor."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="test", enabled=True))
|
||||||
|
|
||||||
|
assert registry.set_vendor_enabled("test", False) is True
|
||||||
|
assert registry.get_vendor("test").enabled is False
|
||||||
|
|
||||||
|
assert registry.set_vendor_enabled("test", True) is True
|
||||||
|
assert registry.get_vendor("test").enabled is True
|
||||||
|
|
||||||
|
def test_set_vendor_enabled_nonexistent(self):
|
||||||
|
"""Test setting enabled on non-existent vendor."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
assert registry.set_vendor_enabled("nonexistent", True) is False
|
||||||
|
|
||||||
|
def test_set_vendor_priority(self):
|
||||||
|
"""Test updating vendor priority."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="test", priority=100))
|
||||||
|
|
||||||
|
assert registry.set_vendor_priority("test", 10) is True
|
||||||
|
assert registry.get_vendor("test").priority == 10
|
||||||
|
|
||||||
|
def test_set_vendor_priority_nonexistent(self):
|
||||||
|
"""Test setting priority on non-existent vendor."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
assert registry.set_vendor_priority("nonexistent", 10) is False
|
||||||
|
|
||||||
|
def test_clear(self):
|
||||||
|
"""Test clearing all registrations."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(
|
||||||
|
name="test",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA}
|
||||||
|
))
|
||||||
|
registry.register_method("test", "method", Mock())
|
||||||
|
|
||||||
|
registry.clear()
|
||||||
|
|
||||||
|
assert registry.get_vendor("test") is None
|
||||||
|
assert len(registry.get_all_vendors()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestVendorRegistryThreadSafety:
|
||||||
|
"""Tests for thread safety of VendorRegistry."""
|
||||||
|
|
||||||
|
def test_concurrent_registration(self):
|
||||||
|
"""Test concurrent vendor registration is thread-safe."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def register_vendor(i):
|
||||||
|
try:
|
||||||
|
registry.register_vendor(VendorMetadata(
|
||||||
|
name=f"vendor_{i}",
|
||||||
|
capabilities={VendorCapability.STOCK_DATA}
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=register_vendor, args=(i,)) for i in range(50)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert len(errors) == 0
|
||||||
|
assert len(registry.get_all_vendors()) == 50
|
||||||
|
|
||||||
|
def test_concurrent_method_registration(self):
|
||||||
|
"""Test concurrent method registration is thread-safe."""
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(VendorMetadata(name="test"))
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def register_method(i):
|
||||||
|
try:
|
||||||
|
registry.register_method("test", f"method_{i}", Mock())
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=register_method, args=(i,)) for i in range(50)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert len(errors) == 0
|
||||||
|
assert len(registry.get_methods_for_vendor("test")) == 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestModuleFunctions:
|
||||||
|
"""Tests for module-level convenience functions."""
|
||||||
|
|
||||||
|
def test_register_vendor_function(self):
|
||||||
|
"""Test module-level register_vendor function."""
|
||||||
|
metadata = VendorMetadata(name="module_test")
|
||||||
|
register_vendor(metadata)
|
||||||
|
|
||||||
|
registry = get_registry()
|
||||||
|
assert registry.get_vendor("module_test") is not None
|
||||||
|
|
||||||
|
def test_register_method_function(self):
|
||||||
|
"""Test module-level register_method function."""
|
||||||
|
metadata = VendorMetadata(name="method_test")
|
||||||
|
register_vendor(metadata)
|
||||||
|
|
||||||
|
mock_func = Mock()
|
||||||
|
register_method("method_test", "test_method", mock_func)
|
||||||
|
|
||||||
|
registry = get_registry()
|
||||||
|
assert registry.get_method("method_test", "test_method") is mock_func
|
||||||
|
|
@ -0,0 +1,365 @@
|
||||||
|
"""Base vendor abstract class for data provider implementations.
|
||||||
|
|
||||||
|
This module provides the abstract base class that all data vendors must inherit from,
|
||||||
|
implementing a 3-stage data pipeline: transform_query → extract_data → transform_data.
|
||||||
|
|
||||||
|
Issue #11: [DATA-10] Interface routing - add new data vendors
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, TypeVar, Generic
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VendorResponse(Generic[T]):
|
||||||
|
"""Standardized response from vendor data operations.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
data: The extracted data
|
||||||
|
success: Whether the operation succeeded
|
||||||
|
vendor_name: Name of the vendor that provided the data
|
||||||
|
method_name: Name of the method called
|
||||||
|
execution_time_ms: Time taken in milliseconds
|
||||||
|
error_message: Error message if operation failed
|
||||||
|
metadata: Additional metadata about the response
|
||||||
|
"""
|
||||||
|
data: Optional[T] = None
|
||||||
|
success: bool = True
|
||||||
|
vendor_name: str = ""
|
||||||
|
method_name: str = ""
|
||||||
|
execution_time_ms: float = 0.0
|
||||||
|
error_message: str = ""
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
"""Check if response contains no data."""
|
||||||
|
if self.data is None:
|
||||||
|
return True
|
||||||
|
if isinstance(self.data, (list, dict, str)):
|
||||||
|
return len(self.data) == 0
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class BaseVendor(ABC):
|
||||||
|
"""Abstract base class for all data vendors.
|
||||||
|
|
||||||
|
Implements a 3-stage pipeline pattern:
|
||||||
|
1. transform_query: Normalize input parameters
|
||||||
|
2. extract_data: Fetch raw data from source
|
||||||
|
3. transform_data: Normalize output format
|
||||||
|
|
||||||
|
Thread-safe call counting with configurable retry logic.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class YFinanceVendor(BaseVendor):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "yfinance"
|
||||||
|
|
||||||
|
def _transform_query(self, method, **kwargs):
|
||||||
|
# Normalize ticker symbol
|
||||||
|
return {"symbol": kwargs["ticker"].upper()}
|
||||||
|
|
||||||
|
def _extract_data(self, method, query):
|
||||||
|
# Fetch from yfinance
|
||||||
|
return yf.Ticker(query["symbol"]).history()
|
||||||
|
|
||||||
|
def _transform_data(self, method, raw_data, query):
|
||||||
|
# Convert to standard format
|
||||||
|
return {"ohlcv": raw_data.to_dict()}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay: float = 1.0,
|
||||||
|
retry_backoff: float = 2.0,
|
||||||
|
timeout: Optional[float] = None
|
||||||
|
):
|
||||||
|
"""Initialize vendor with retry configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
retry_delay: Initial delay between retries in seconds
|
||||||
|
retry_backoff: Multiplier for delay after each retry
|
||||||
|
timeout: Optional timeout for operations in seconds
|
||||||
|
"""
|
||||||
|
self._max_retries = max_retries
|
||||||
|
self._retry_delay = retry_delay
|
||||||
|
self._retry_backoff = retry_backoff
|
||||||
|
self._timeout = timeout
|
||||||
|
self._call_count = 0
|
||||||
|
self._call_count_lock = threading.Lock()
|
||||||
|
self._last_call_time: Optional[datetime] = None
|
||||||
|
self._error_count = 0
|
||||||
|
self._success_count = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the vendor name."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def call_count(self) -> int:
|
||||||
|
"""Get the total number of calls made (thread-safe)."""
|
||||||
|
with self._call_count_lock:
|
||||||
|
return self._call_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def error_rate(self) -> float:
|
||||||
|
"""Calculate error rate as percentage."""
|
||||||
|
total = self._error_count + self._success_count
|
||||||
|
if total == 0:
|
||||||
|
return 0.0
|
||||||
|
return (self._error_count / total) * 100
|
||||||
|
|
||||||
|
def reset_stats(self) -> None:
|
||||||
|
"""Reset call statistics."""
|
||||||
|
with self._call_count_lock:
|
||||||
|
self._call_count = 0
|
||||||
|
self._error_count = 0
|
||||||
|
self._success_count = 0
|
||||||
|
self._last_call_time = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _transform_query(self, method: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
"""Transform input parameters to vendor-specific format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: The method being called
|
||||||
|
**kwargs: Input parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed query parameters
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _extract_data(self, method: str, query: Dict[str, Any]) -> Any:
|
||||||
|
"""Extract raw data from the vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: The method being called
|
||||||
|
query: Transformed query parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw data from the vendor
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _transform_data(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
raw_data: Any,
|
||||||
|
query: Dict[str, Any]
|
||||||
|
) -> Any:
|
||||||
|
"""Transform raw data to standardized format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: The method being called
|
||||||
|
raw_data: Raw data from extract_data
|
||||||
|
query: Original query parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed data in standard format
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _should_retry(self, exception: Exception) -> bool:
|
||||||
|
"""Determine if operation should be retried for given exception.
|
||||||
|
|
||||||
|
Override this method to customize retry logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exception: The exception that occurred
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if operation should be retried
|
||||||
|
"""
|
||||||
|
# Default: retry on network-related errors
|
||||||
|
retryable_types = (
|
||||||
|
ConnectionError,
|
||||||
|
TimeoutError,
|
||||||
|
OSError,
|
||||||
|
)
|
||||||
|
return isinstance(exception, retryable_types)
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
**kwargs
|
||||||
|
) -> VendorResponse:
|
||||||
|
"""Execute a vendor method with full pipeline.
|
||||||
|
|
||||||
|
Runs the 3-stage pipeline with retry logic:
|
||||||
|
1. transform_query
|
||||||
|
2. extract_data (with retries)
|
||||||
|
3. transform_data
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: Name of the method to execute
|
||||||
|
**kwargs: Parameters for the method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VendorResponse with result or error information
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Increment call count (thread-safe)
|
||||||
|
with self._call_count_lock:
|
||||||
|
self._call_count += 1
|
||||||
|
self._last_call_time = datetime.now()
|
||||||
|
|
||||||
|
response = VendorResponse(
|
||||||
|
vendor_name=self.name,
|
||||||
|
method_name=method
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Stage 1: Transform query
|
||||||
|
query = self._transform_query(method, **kwargs)
|
||||||
|
logger.debug(f"[{self.name}] Transformed query for {method}: {query}")
|
||||||
|
|
||||||
|
# Stage 2: Extract data with retries
|
||||||
|
raw_data = self._extract_with_retry(method, query)
|
||||||
|
|
||||||
|
# Stage 3: Transform data
|
||||||
|
transformed = self._transform_data(method, raw_data, query)
|
||||||
|
|
||||||
|
response.data = transformed
|
||||||
|
response.success = True
|
||||||
|
self._success_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
response.success = False
|
||||||
|
response.error_message = str(e)
|
||||||
|
self._error_count += 1
|
||||||
|
logger.error(f"[{self.name}] Error executing {method}: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
response.execution_time_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _extract_with_retry(self, method: str, query: Dict[str, Any]) -> Any:
|
||||||
|
"""Execute extract_data with retry logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: Method name
|
||||||
|
query: Query parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw data from vendor
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Last exception if all retries fail
|
||||||
|
"""
|
||||||
|
last_exception: Optional[Exception] = None
|
||||||
|
delay = self._retry_delay
|
||||||
|
|
||||||
|
for attempt in range(self._max_retries + 1):
|
||||||
|
try:
|
||||||
|
return self._extract_data(method, query)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
if not self._should_retry(e):
|
||||||
|
logger.debug(f"[{self.name}] Non-retryable error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if attempt < self._max_retries:
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.name}] Retry {attempt + 1}/{self._max_retries} "
|
||||||
|
f"after error: {e}. Waiting {delay:.1f}s"
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
delay *= self._retry_backoff
|
||||||
|
|
||||||
|
# All retries exhausted
|
||||||
|
raise last_exception # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleVendor(BaseVendor):
|
||||||
|
"""Simplified vendor for wrapping existing functions.
|
||||||
|
|
||||||
|
Useful for migrating existing vendor implementations to the new pattern
|
||||||
|
without major refactoring.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
vendor = SimpleVendor(
|
||||||
|
vendor_name="yfinance",
|
||||||
|
methods={
|
||||||
|
"get_stock_data": get_yfinance_stock,
|
||||||
|
"get_fundamentals": get_yfinance_fundamentals,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vendor_name: str,
|
||||||
|
methods: Dict[str, callable],
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""Initialize simple vendor with existing functions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vendor_name: Name of the vendor
|
||||||
|
methods: Dictionary mapping method names to callables
|
||||||
|
**kwargs: Additional BaseVendor configuration
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._vendor_name = vendor_name
|
||||||
|
self._methods = methods
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the vendor name."""
|
||||||
|
return self._vendor_name
|
||||||
|
|
||||||
|
def _transform_query(self, method: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
"""Pass through query parameters unchanged."""
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def _extract_data(self, method: str, query: Dict[str, Any]) -> Any:
|
||||||
|
"""Call the wrapped function."""
|
||||||
|
if method not in self._methods:
|
||||||
|
raise ValueError(f"Method '{method}' not found in vendor '{self.name}'")
|
||||||
|
|
||||||
|
func = self._methods[method]
|
||||||
|
return func(**query)
|
||||||
|
|
||||||
|
def _transform_data(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
raw_data: Any,
|
||||||
|
query: Dict[str, Any]
|
||||||
|
) -> Any:
|
||||||
|
"""Return data unchanged."""
|
||||||
|
return raw_data
|
||||||
|
|
||||||
|
def add_method(self, method_name: str, func: callable) -> None:
|
||||||
|
"""Add a method to the vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method_name: Name of the method
|
||||||
|
func: Callable to execute
|
||||||
|
"""
|
||||||
|
self._methods[method_name] = func
|
||||||
|
|
||||||
|
def get_methods(self) -> List[str]:
|
||||||
|
"""Get list of available methods."""
|
||||||
|
return list(self._methods.keys())
|
||||||
|
|
@ -0,0 +1,351 @@
|
||||||
|
"""Decorators for vendor registration and method management.
|
||||||
|
|
||||||
|
Provides convenient decorators for registering vendors and methods with
|
||||||
|
the global registry.
|
||||||
|
|
||||||
|
Issue #11: [DATA-10] Interface routing - add new data vendors
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Callable, Optional, Set, Type
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .vendor_registry import (
|
||||||
|
VendorCapability,
|
||||||
|
VendorMetadata,
|
||||||
|
VendorRegistry,
|
||||||
|
get_registry
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def register_vendor(
|
||||||
|
name: str,
|
||||||
|
priority: int = 100,
|
||||||
|
capabilities: Optional[Set[VendorCapability]] = None,
|
||||||
|
rate_limit_exception: Optional[Type[Exception]] = None,
|
||||||
|
description: str = ""
|
||||||
|
) -> Callable:
|
||||||
|
"""Class decorator to register a vendor with the global registry.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@register_vendor(
|
||||||
|
name="yfinance",
|
||||||
|
priority=10,
|
||||||
|
capabilities={VendorCapability.STOCK_DATA, VendorCapability.FUNDAMENTALS}
|
||||||
|
)
|
||||||
|
class YFinanceVendor(BaseVendor):
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
def decorator(cls: Type) -> Type:
|
||||||
|
metadata = VendorMetadata(
|
||||||
|
name=name,
|
||||||
|
priority=priority,
|
||||||
|
capabilities=capabilities or set(),
|
||||||
|
rate_limit_exception=rate_limit_exception,
|
||||||
|
description=description
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register vendor on class definition
|
||||||
|
get_registry().register_vendor(metadata)
|
||||||
|
|
||||||
|
# Store metadata on class for reference
|
||||||
|
cls._vendor_metadata = metadata
|
||||||
|
cls._vendor_name = name
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def vendor_method(
|
||||||
|
method_name: str,
|
||||||
|
vendor_name: Optional[str] = None
|
||||||
|
) -> Callable:
|
||||||
|
"""Decorator to register a method with the vendor registry.
|
||||||
|
|
||||||
|
Can be used as a method decorator on vendor classes or as a
|
||||||
|
standalone function decorator.
|
||||||
|
|
||||||
|
Example (on class method):
|
||||||
|
class YFinanceVendor(BaseVendor):
|
||||||
|
@vendor_method("get_stock_data")
|
||||||
|
def get_stock(self, ticker):
|
||||||
|
pass
|
||||||
|
|
||||||
|
Example (standalone):
|
||||||
|
@vendor_method("get_stock_data", vendor_name="yfinance")
|
||||||
|
def get_yfinance_stock(ticker):
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Store registration info for lazy registration
|
||||||
|
wrapper._vendor_method_name = method_name
|
||||||
|
wrapper._vendor_name = vendor_name
|
||||||
|
|
||||||
|
# If vendor_name provided, register immediately
|
||||||
|
if vendor_name:
|
||||||
|
try:
|
||||||
|
get_registry().register_method(vendor_name, method_name, wrapper)
|
||||||
|
except ValueError:
|
||||||
|
# Vendor not yet registered - will be registered later
|
||||||
|
logger.debug(
|
||||||
|
f"Deferred registration of {method_name} for {vendor_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Thread-safe rate limiter with sliding window.
|
||||||
|
|
||||||
|
Tracks request timestamps and enforces rate limits using a sliding window.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_calls: int, period_seconds: float):
|
||||||
|
"""Initialize rate limiter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_calls: Maximum calls allowed in the period
|
||||||
|
period_seconds: Time window in seconds
|
||||||
|
"""
|
||||||
|
self._max_calls = max_calls
|
||||||
|
self._period = period_seconds
|
||||||
|
self._calls: list = []
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def acquire(self) -> bool:
|
||||||
|
"""Try to acquire a rate limit slot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if request is allowed, False if rate limited
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
now = time.time()
|
||||||
|
cutoff = now - self._period
|
||||||
|
|
||||||
|
# Remove expired entries
|
||||||
|
self._calls = [t for t in self._calls if t > cutoff]
|
||||||
|
|
||||||
|
# Check if under limit
|
||||||
|
if len(self._calls) < self._max_calls:
|
||||||
|
self._calls.append(now)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def wait_time(self) -> float:
|
||||||
|
"""Get time to wait before next allowed request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Seconds to wait (0 if request can proceed)
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
now = time.time()
|
||||||
|
cutoff = now - self._period
|
||||||
|
|
||||||
|
# Remove expired entries
|
||||||
|
self._calls = [t for t in self._calls if t > cutoff]
|
||||||
|
|
||||||
|
if len(self._calls) < self._max_calls:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Time until oldest call expires
|
||||||
|
return self._calls[0] + self._period - now
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the rate limiter."""
|
||||||
|
with self._lock:
|
||||||
|
self._calls.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# Global rate limiters for vendors
|
||||||
|
_rate_limiters: dict = {}
|
||||||
|
_rate_limiter_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rate_limiter(
|
||||||
|
vendor_name: str,
|
||||||
|
max_calls: int,
|
||||||
|
period_seconds: float
|
||||||
|
) -> RateLimiter:
|
||||||
|
"""Get or create a rate limiter for a vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vendor_name: Vendor identifier
|
||||||
|
max_calls: Max calls per period
|
||||||
|
period_seconds: Rate limit period
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RateLimiter instance
|
||||||
|
"""
|
||||||
|
with _rate_limiter_lock:
|
||||||
|
key = f"{vendor_name}:{max_calls}:{period_seconds}"
|
||||||
|
if key not in _rate_limiters:
|
||||||
|
_rate_limiters[key] = RateLimiter(max_calls, period_seconds)
|
||||||
|
return _rate_limiters[key]
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limited(
|
||||||
|
max_calls: int,
|
||||||
|
period_seconds: float = 60.0,
|
||||||
|
vendor_name: Optional[str] = None,
|
||||||
|
exception_class: Optional[Type[Exception]] = None
|
||||||
|
) -> Callable:
|
||||||
|
"""Decorator to apply rate limiting to a function.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@rate_limited(max_calls=5, period_seconds=60, vendor_name="alpha_vantage")
|
||||||
|
def get_stock_data(ticker):
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
# Determine vendor name from function or provided value
|
||||||
|
limiter_name = vendor_name or getattr(func, '_vendor_name', func.__name__)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
limiter = get_rate_limiter(limiter_name, max_calls, period_seconds)
|
||||||
|
|
||||||
|
if not limiter.acquire():
|
||||||
|
wait_seconds = limiter.wait_time()
|
||||||
|
error_msg = (
|
||||||
|
f"Rate limit exceeded for {limiter_name}. "
|
||||||
|
f"Try again in {wait_seconds:.1f} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
if exception_class:
|
||||||
|
raise exception_class(error_msg)
|
||||||
|
raise RuntimeError(error_msg)
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def with_retry(
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay: float = 1.0,
|
||||||
|
backoff_multiplier: float = 2.0,
|
||||||
|
retryable_exceptions: tuple = (ConnectionError, TimeoutError)
|
||||||
|
) -> Callable:
|
||||||
|
"""Decorator to add retry logic to a function.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@with_retry(max_retries=3, retry_delay=1.0)
|
||||||
|
def fetch_data():
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
last_exception = None
|
||||||
|
delay = retry_delay
|
||||||
|
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except retryable_exceptions as e:
|
||||||
|
last_exception = e
|
||||||
|
if attempt < max_retries:
|
||||||
|
logger.warning(
|
||||||
|
f"Retry {attempt + 1}/{max_retries} for {func.__name__}: {e}"
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
delay *= backoff_multiplier
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
raise last_exception # type: ignore
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def cache_result(
|
||||||
|
ttl_seconds: float = 300.0,
|
||||||
|
max_size: int = 100
|
||||||
|
) -> Callable:
|
||||||
|
"""Decorator to cache function results.
|
||||||
|
|
||||||
|
Simple TTL-based cache with LRU eviction.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@cache_result(ttl_seconds=60)
|
||||||
|
def get_stock_data(ticker):
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
cache: dict = {}
|
||||||
|
cache_lock = threading.Lock()
|
||||||
|
access_order: list = []
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Create cache key from args
|
||||||
|
key = (args, tuple(sorted(kwargs.items())))
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
with cache_lock:
|
||||||
|
# Check cache
|
||||||
|
if key in cache:
|
||||||
|
value, timestamp = cache[key]
|
||||||
|
if now - timestamp < ttl_seconds:
|
||||||
|
# Update access order
|
||||||
|
if key in access_order:
|
||||||
|
access_order.remove(key)
|
||||||
|
access_order.append(key)
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
# Expired
|
||||||
|
del cache[key]
|
||||||
|
if key in access_order:
|
||||||
|
access_order.remove(key)
|
||||||
|
|
||||||
|
# Execute function
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
|
||||||
|
with cache_lock:
|
||||||
|
# Evict oldest if at capacity
|
||||||
|
while len(cache) >= max_size and access_order:
|
||||||
|
oldest_key = access_order.pop(0)
|
||||||
|
cache.pop(oldest_key, None)
|
||||||
|
|
||||||
|
# Store result
|
||||||
|
cache[key] = (result, now)
|
||||||
|
access_order.append(key)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Add cache control methods
|
||||||
|
def clear_cache():
|
||||||
|
with cache_lock:
|
||||||
|
cache.clear()
|
||||||
|
access_order.clear()
|
||||||
|
|
||||||
|
def cache_info():
|
||||||
|
with cache_lock:
|
||||||
|
return {
|
||||||
|
"size": len(cache),
|
||||||
|
"max_size": max_size,
|
||||||
|
"ttl_seconds": ttl_seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapper.clear_cache = clear_cache
|
||||||
|
wrapper.cache_info = cache_info
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
@ -0,0 +1,329 @@
|
||||||
|
"""Vendor Registry for extensible data vendor management.
|
||||||
|
|
||||||
|
This module provides a thread-safe registry pattern for managing data vendors,
|
||||||
|
enabling easy addition of new vendors without modifying core interface code.
|
||||||
|
|
||||||
|
Issue #11: [DATA-10] Interface routing - add new data vendors
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Callable, Dict, List, Optional, Set, Any, Type
|
||||||
|
import threading
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VendorCapability(Enum):
|
||||||
|
"""Capabilities that vendors can provide."""
|
||||||
|
STOCK_DATA = auto()
|
||||||
|
TECHNICAL_INDICATORS = auto()
|
||||||
|
FUNDAMENTALS = auto()
|
||||||
|
BALANCE_SHEET = auto()
|
||||||
|
CASHFLOW = auto()
|
||||||
|
INCOME_STATEMENT = auto()
|
||||||
|
NEWS = auto()
|
||||||
|
GLOBAL_NEWS = auto()
|
||||||
|
INSIDER_SENTIMENT = auto()
|
||||||
|
INSIDER_TRANSACTIONS = auto()
|
||||||
|
MACROECONOMIC = auto()
|
||||||
|
BENCHMARK = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VendorMetadata:
|
||||||
|
"""Metadata about a registered vendor."""
|
||||||
|
name: str
|
||||||
|
priority: int = 100 # Lower = higher priority
|
||||||
|
capabilities: Set[VendorCapability] = field(default_factory=set)
|
||||||
|
rate_limit_exception: Optional[Type[Exception]] = None
|
||||||
|
description: str = ""
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.name)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, VendorMetadata):
|
||||||
|
return False
|
||||||
|
return self.name == other.name
|
||||||
|
|
||||||
|
|
||||||
|
class VendorRegistry:
|
||||||
|
"""Thread-safe singleton registry for data vendors.
|
||||||
|
|
||||||
|
Provides centralized vendor registration, lookup, and routing.
|
||||||
|
Uses double-checked locking for thread-safe singleton pattern.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Register a vendor
|
||||||
|
registry = VendorRegistry()
|
||||||
|
registry.register_vendor(
|
||||||
|
VendorMetadata(
|
||||||
|
name="yfinance",
|
||||||
|
priority=10,
|
||||||
|
capabilities={VendorCapability.STOCK_DATA, VendorCapability.FUNDAMENTALS}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register a method
|
||||||
|
registry.register_method("yfinance", "get_stock_data", get_yfinance_stock)
|
||||||
|
|
||||||
|
# Get vendors for a capability
|
||||||
|
vendors = registry.get_vendors_for_capability(VendorCapability.STOCK_DATA)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional["VendorRegistry"] = None
|
||||||
|
_lock: threading.Lock = threading.Lock()
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
def __new__(cls) -> "VendorRegistry":
|
||||||
|
"""Thread-safe singleton instantiation with double-checked locking."""
|
||||||
|
# Fast path: instance exists
|
||||||
|
if cls._instance is not None:
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
# Slow path: acquire lock and create instance
|
||||||
|
with cls._lock:
|
||||||
|
# Double-check inside lock
|
||||||
|
if cls._instance is None:
|
||||||
|
instance = super().__new__(cls)
|
||||||
|
# Initialize instance attributes before publishing
|
||||||
|
instance._vendors: Dict[str, VendorMetadata] = {}
|
||||||
|
instance._methods: Dict[str, Dict[str, Callable]] = {}
|
||||||
|
instance._capability_index: Dict[VendorCapability, Set[str]] = {}
|
||||||
|
instance._vendor_lock = threading.RLock()
|
||||||
|
# Publish instance only after fully initialized
|
||||||
|
cls._instance = instance
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize registry (only runs once due to singleton)."""
|
||||||
|
# Skip if already initialized
|
||||||
|
if VendorRegistry._initialized:
|
||||||
|
return
|
||||||
|
VendorRegistry._initialized = True
|
||||||
|
|
||||||
|
def register_vendor(self, metadata: VendorMetadata) -> None:
|
||||||
|
"""Register a new vendor with its metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: VendorMetadata containing vendor info and capabilities
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If vendor name is empty
|
||||||
|
"""
|
||||||
|
if not metadata.name:
|
||||||
|
raise ValueError("Vendor name cannot be empty")
|
||||||
|
|
||||||
|
with self._vendor_lock:
|
||||||
|
self._vendors[metadata.name] = metadata
|
||||||
|
|
||||||
|
# Update capability index
|
||||||
|
for capability in metadata.capabilities:
|
||||||
|
if capability not in self._capability_index:
|
||||||
|
self._capability_index[capability] = set()
|
||||||
|
self._capability_index[capability].add(metadata.name)
|
||||||
|
|
||||||
|
logger.debug(f"Registered vendor: {metadata.name} with capabilities: {metadata.capabilities}")
|
||||||
|
|
||||||
|
def unregister_vendor(self, name: str) -> bool:
|
||||||
|
"""Unregister a vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of vendor to unregister
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if vendor was unregistered, False if not found
|
||||||
|
"""
|
||||||
|
with self._vendor_lock:
|
||||||
|
if name not in self._vendors:
|
||||||
|
return False
|
||||||
|
|
||||||
|
metadata = self._vendors.pop(name)
|
||||||
|
|
||||||
|
# Remove from capability index
|
||||||
|
for capability in metadata.capabilities:
|
||||||
|
if capability in self._capability_index:
|
||||||
|
self._capability_index[capability].discard(name)
|
||||||
|
|
||||||
|
# Remove all registered methods
|
||||||
|
if name in self._methods:
|
||||||
|
del self._methods[name]
|
||||||
|
|
||||||
|
logger.debug(f"Unregistered vendor: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def register_method(
|
||||||
|
self,
|
||||||
|
vendor_name: str,
|
||||||
|
method_name: str,
|
||||||
|
implementation: Callable
|
||||||
|
) -> None:
|
||||||
|
"""Register a method implementation for a vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vendor_name: Name of the vendor
|
||||||
|
method_name: Name of the method
|
||||||
|
implementation: Callable implementation
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If vendor is not registered or method name is empty
|
||||||
|
"""
|
||||||
|
if not method_name:
|
||||||
|
raise ValueError("Method name cannot be empty")
|
||||||
|
|
||||||
|
with self._vendor_lock:
|
||||||
|
if vendor_name not in self._vendors:
|
||||||
|
raise ValueError(f"Vendor '{vendor_name}' not registered")
|
||||||
|
|
||||||
|
if vendor_name not in self._methods:
|
||||||
|
self._methods[vendor_name] = {}
|
||||||
|
|
||||||
|
self._methods[vendor_name][method_name] = implementation
|
||||||
|
logger.debug(f"Registered method '{method_name}' for vendor '{vendor_name}'")
|
||||||
|
|
||||||
|
def get_vendor(self, name: str) -> Optional[VendorMetadata]:
|
||||||
|
"""Get vendor metadata by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Vendor name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VendorMetadata if found, None otherwise
|
||||||
|
"""
|
||||||
|
with self._vendor_lock:
|
||||||
|
return self._vendors.get(name)
|
||||||
|
|
||||||
|
def get_all_vendors(self) -> List[VendorMetadata]:
|
||||||
|
"""Get all registered vendors sorted by priority.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of VendorMetadata sorted by priority (lower = higher)
|
||||||
|
"""
|
||||||
|
with self._vendor_lock:
|
||||||
|
return sorted(
|
||||||
|
self._vendors.values(),
|
||||||
|
key=lambda v: v.priority
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_vendors_for_capability(
|
||||||
|
self,
|
||||||
|
capability: VendorCapability,
|
||||||
|
only_enabled: bool = True
|
||||||
|
) -> List[VendorMetadata]:
|
||||||
|
"""Get all vendors that support a specific capability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
capability: The capability to find vendors for
|
||||||
|
only_enabled: If True, only return enabled vendors
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of VendorMetadata sorted by priority
|
||||||
|
"""
|
||||||
|
with self._vendor_lock:
|
||||||
|
vendor_names = self._capability_index.get(capability, set())
|
||||||
|
vendors = [
|
||||||
|
self._vendors[name]
|
||||||
|
for name in vendor_names
|
||||||
|
if name in self._vendors
|
||||||
|
]
|
||||||
|
|
||||||
|
if only_enabled:
|
||||||
|
vendors = [v for v in vendors if v.enabled]
|
||||||
|
|
||||||
|
return sorted(vendors, key=lambda v: v.priority)
|
||||||
|
|
||||||
|
def get_method(
|
||||||
|
self,
|
||||||
|
vendor_name: str,
|
||||||
|
method_name: str
|
||||||
|
) -> Optional[Callable]:
|
||||||
|
"""Get a method implementation for a vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vendor_name: Name of the vendor
|
||||||
|
method_name: Name of the method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable if found, None otherwise
|
||||||
|
"""
|
||||||
|
with self._vendor_lock:
|
||||||
|
vendor_methods = self._methods.get(vendor_name, {})
|
||||||
|
return vendor_methods.get(method_name)
|
||||||
|
|
||||||
|
def get_methods_for_vendor(self, vendor_name: str) -> Dict[str, Callable]:
|
||||||
|
"""Get all methods registered for a vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vendor_name: Name of the vendor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping method names to implementations
|
||||||
|
"""
|
||||||
|
with self._vendor_lock:
|
||||||
|
return dict(self._methods.get(vendor_name, {}))
|
||||||
|
|
||||||
|
def set_vendor_enabled(self, name: str, enabled: bool) -> bool:
|
||||||
|
"""Enable or disable a vendor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Vendor name
|
||||||
|
enabled: Whether vendor should be enabled
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if vendor was found and updated, False otherwise
|
||||||
|
"""
|
||||||
|
with self._vendor_lock:
|
||||||
|
if name not in self._vendors:
|
||||||
|
return False
|
||||||
|
self._vendors[name].enabled = enabled
|
||||||
|
return True
|
||||||
|
|
||||||
|
def set_vendor_priority(self, name: str, priority: int) -> bool:
|
||||||
|
"""Update a vendor's priority.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Vendor name
|
||||||
|
priority: New priority (lower = higher priority)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if vendor was found and updated, False otherwise
|
||||||
|
"""
|
||||||
|
with self._vendor_lock:
|
||||||
|
if name not in self._vendors:
|
||||||
|
return False
|
||||||
|
self._vendors[name].priority = priority
|
||||||
|
return True
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear all registrations. Primarily for testing."""
|
||||||
|
with self._vendor_lock:
|
||||||
|
self._vendors.clear()
|
||||||
|
self._methods.clear()
|
||||||
|
self._capability_index.clear()
|
||||||
|
logger.debug("Cleared all vendor registrations")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_instance(cls) -> None:
|
||||||
|
"""Reset the singleton instance. For testing only."""
|
||||||
|
with cls._lock:
|
||||||
|
cls._instance = None
|
||||||
|
cls._initialized = False
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level convenience functions
|
||||||
|
def get_registry() -> VendorRegistry:
|
||||||
|
"""Get the global vendor registry instance."""
|
||||||
|
return VendorRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
def register_vendor(metadata: VendorMetadata) -> None:
|
||||||
|
"""Register a vendor in the global registry."""
|
||||||
|
get_registry().register_vendor(metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def register_method(vendor_name: str, method_name: str, implementation: Callable) -> None:
|
||||||
|
"""Register a method in the global registry."""
|
||||||
|
get_registry().register_method(vendor_name, method_name, implementation)
|
||||||
Loading…
Reference in New Issue