352 lines
9.8 KiB
Python
352 lines
9.8 KiB
Python
"""Decorators for vendor registration and method management.
|
|
|
|
Provides convenient decorators for registering vendors and methods with
|
|
the global registry.
|
|
|
|
Issue #11: [DATA-10] Interface routing - add new data vendors
|
|
"""
|
|
|
|
from functools import wraps
|
|
from typing import Callable, Optional, Set, Type
|
|
import threading
|
|
import time
|
|
import logging
|
|
|
|
from .vendor_registry import (
|
|
VendorCapability,
|
|
VendorMetadata,
|
|
VendorRegistry,
|
|
get_registry
|
|
)
|
|
|
|
# Module-level logger named after this module; used by the decorators below
# for deferred-registration and retry messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def register_vendor(
    name: str,
    priority: int = 100,
    capabilities: Optional[Set[VendorCapability]] = None,
    rate_limit_exception: Optional[Type[Exception]] = None,
    description: str = ""
) -> Callable:
    """Build a class decorator that registers a vendor with the global registry.

    When applied to a class, the returned decorator builds a
    ``VendorMetadata`` from the arguments given here, hands it to
    ``get_registry().register_vendor`` and then records the metadata and the
    vendor name on the class itself.

    Args:
        name: Vendor name; also stored on the class as ``_vendor_name``.
        priority: Passed through to ``VendorMetadata`` unchanged.
        capabilities: Capabilities passed to ``VendorMetadata``. A falsy value
            (``None`` or an empty set) is replaced by a new empty set.
        rate_limit_exception: Passed through to ``VendorMetadata`` unchanged.
        description: Passed through to ``VendorMetadata`` unchanged.

    Returns:
        A decorator that returns the very class it was given, after
        registration.

    Example:
        @register_vendor(
            name="yfinance",
            priority=10,
            capabilities={VendorCapability.STOCK_DATA, VendorCapability.FUNDAMENTALS}
        )
        class YFinanceVendor(BaseVendor):
            pass
    """
    def decorator(cls: Type) -> Type:
        vendor_capabilities = capabilities if capabilities else set()
        vendor_info = VendorMetadata(
            name=name,
            priority=priority,
            capabilities=vendor_capabilities,
            rate_limit_exception=rate_limit_exception,
            description=description
        )

        # Registration happens as soon as the class is defined, and before the
        # class is annotated, so a failing registration leaves it untouched.
        registry = get_registry()
        registry.register_vendor(vendor_info)

        # Keep the metadata reachable from the class for later lookups.
        cls._vendor_metadata = vendor_info
        cls._vendor_name = name
        return cls

    return decorator
|
|
|
|
|
|
def vendor_method(
    method_name: str,
    vendor_name: Optional[str] = None
) -> Callable:
    """Build a decorator that tags a callable as a vendor method.

    The decorated callable is wrapped in a pass-through function that carries
    two attributes: ``_vendor_method_name`` (``method_name``) and
    ``_vendor_name`` (``vendor_name``, possibly ``None``). If ``vendor_name``
    is truthy the wrapper is also registered right away through
    ``get_registry().register_method``.

    A ``ValueError`` from that registration is caught and logged at debug
    level instead of being raised.
    NOTE(review): the log message calls this a "deferred" registration, but
    nothing in this function retries it - confirm where that happens.

    Args:
        method_name: Name under which the method is registered.
        vendor_name: Vendor to register with immediately; when omitted or
            empty, only the attributes are set.

    Returns:
        A decorator returning the pass-through wrapper.

    Example (on class method):
        class YFinanceVendor(BaseVendor):
            @vendor_method("get_stock_data")
            def get_stock(self, ticker):
                pass

    Example (standalone):
        @vendor_method("get_stock_data", vendor_name="yfinance")
        def get_yfinance_stock(ticker):
            pass
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def forwarder(*call_args, **call_kwargs):
            return func(*call_args, **call_kwargs)

        # Registration details travel with the wrapper itself.
        forwarder._vendor_method_name = method_name
        forwarder._vendor_name = vendor_name

        # No vendor given: tagging is all that happens here.
        if not vendor_name:
            return forwarder

        try:
            get_registry().register_method(vendor_name, method_name, forwarder)
        except ValueError:
            # The registry rejected the registration; log it and carry on.
            logger.debug(
                f"Deferred registration of {method_name} for {vendor_name}"
            )
        return forwarder

    return decorator
|
|
|
|
|
|
class RateLimiter:
    """Thread-safe rate limiter with a sliding window.

    Keeps the timestamps of granted calls and allows a new call only while
    fewer than ``max_calls`` of them lie within the last ``period_seconds``.
    All state is read and modified under a single lock.
    """

    def __init__(self, max_calls: int, period_seconds: float):
        """Initialize rate limiter.

        Args:
            max_calls: Maximum calls allowed in the period. With a value of
                zero or less no call is ever allowed.
            period_seconds: Time window in seconds
        """
        self._max_calls = max_calls
        self._period = period_seconds
        # Timestamps of granted calls, oldest first.
        self._calls: list = []
        self._lock = threading.Lock()

    def _prune(self, now: float) -> None:
        """Drop timestamps that fell out of the window; caller holds the lock."""
        cutoff = now - self._period
        self._calls = [t for t in self._calls if t > cutoff]

    def acquire(self) -> bool:
        """Try to acquire a rate limit slot.

        Returns:
            True if request is allowed (and has been recorded), False if
            rate limited
        """
        with self._lock:
            now = time.time()
            self._prune(now)

            if len(self._calls) < self._max_calls:
                self._calls.append(now)
                return True

            return False

    def wait_time(self) -> float:
        """Get time to wait before next allowed request.

        Returns:
            Seconds to wait (0.0 if a request can proceed now). Returns
            ``float("inf")`` when the limiter can never grant a slot
            (``max_calls`` of zero or less).
        """
        with self._lock:
            now = time.time()
            self._prune(now)

            if len(self._calls) < self._max_calls:
                return 0.0

            # Only reachable with max_calls <= 0: nothing is recorded, so no
            # slot will ever free up. Previously this path raised IndexError.
            if not self._calls:
                return float("inf")

            # A slot frees up when the oldest recorded call leaves the window;
            # the clamp guards against a tiny negative from float rounding.
            return max(0.0, self._calls[0] + self._period - now)

    def reset(self) -> None:
        """Reset the rate limiter by forgetting all recorded calls."""
        with self._lock:
            self._calls.clear()
|
|
|
|
|
|
# Global rate limiters for vendors, keyed by
# (vendor_name, max_calls, period_seconds); guarded by _rate_limiter_lock.
_rate_limiters: dict = {}
_rate_limiter_lock = threading.Lock()


def get_rate_limiter(
    vendor_name: str,
    max_calls: int,
    period_seconds: float
) -> RateLimiter:
    """Get or create a rate limiter for a vendor.

    Limiters are shared: calls with equal ``vendor_name``, ``max_calls`` and
    ``period_seconds`` return the same instance. The lookup key is a tuple,
    so numerically equal settings such as ``60`` and ``60.0`` map to the same
    limiter (a formatted string key would keep them apart).

    Args:
        vendor_name: Vendor identifier
        max_calls: Max calls per period
        period_seconds: Rate limit period

    Returns:
        RateLimiter instance
    """
    key = (vendor_name, max_calls, period_seconds)
    with _rate_limiter_lock:
        limiter = _rate_limiters.get(key)
        if limiter is None:
            limiter = RateLimiter(max_calls, period_seconds)
            _rate_limiters[key] = limiter
        return limiter
|
|
|
|
|
|
def rate_limited(
    max_calls: int,
    period_seconds: float = 60.0,
    vendor_name: Optional[str] = None,
    exception_class: Optional[Type[Exception]] = None
) -> Callable:
    """Decorator to apply rate limiting to a function.

    Each call first asks the shared limiter (from ``get_rate_limiter``) for a
    slot. If none is available the wrapped function is not called and an
    exception is raised instead.

    The limiter name is the first non-empty value of: ``vendor_name``, the
    function's ``_vendor_name`` attribute, the function's ``__name__``. An
    attribute that exists but is ``None`` or empty is skipped, so functions
    tagged without a vendor do not all share a limiter named ``None``.

    Args:
        max_calls: Maximum calls allowed per period.
        period_seconds: Length of the rate limit window in seconds.
        vendor_name: Explicit limiter name; see above for the fallback.
        exception_class: Exception type raised when rate limited; it is
            called with the error message as its only argument.

    Returns:
        A decorator producing the rate-limited wrapper.

    Raises:
        exception_class: When rate limited and ``exception_class`` is given.
        RuntimeError: When rate limited and no ``exception_class`` is given.

    Example:
        @rate_limited(max_calls=5, period_seconds=60, vendor_name="alpha_vantage")
        def get_stock_data(ticker):
            pass
    """
    def decorator(func: Callable) -> Callable:
        # Determine vendor name from provided value or function, skipping
        # missing/None/empty values at each step.
        limiter_name = (
            vendor_name
            or getattr(func, '_vendor_name', None)
            or func.__name__
        )

        @wraps(func)
        def wrapper(*args, **kwargs):
            limiter = get_rate_limiter(limiter_name, max_calls, period_seconds)

            if not limiter.acquire():
                wait_seconds = limiter.wait_time()
                error_msg = (
                    f"Rate limit exceeded for {limiter_name}. "
                    f"Try again in {wait_seconds:.1f} seconds"
                )

                if exception_class:
                    raise exception_class(error_msg)
                raise RuntimeError(error_msg)

            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
def with_retry(
    max_retries: int = 3,
    retry_delay: float = 1.0,
    backoff_multiplier: float = 2.0,
    retryable_exceptions: tuple = (ConnectionError, TimeoutError)
) -> Callable:
    """Decorator to add retry logic to a function.

    The wrapped function is called up to ``max_retries + 1`` times. After a
    failed attempt with one of ``retryable_exceptions`` a warning is logged,
    the wrapper sleeps, and the delay is multiplied by ``backoff_multiplier``
    for the next round. The exception from the last attempt, and any
    exception not listed as retryable, propagates unchanged.

    Args:
        max_retries: Number of retries after the first attempt. Negative
            values are treated as zero so the function always runs once
            (previously they failed with a ``TypeError`` before the function
            was ever called).
        retry_delay: Seconds to sleep before the first retry.
        backoff_multiplier: Factor applied to the delay after every retry.
        retryable_exceptions: Exception types that trigger a retry.

    Returns:
        A decorator producing the retrying wrapper.

    Example:
        @with_retry(max_retries=3, retry_delay=1.0)
        def fetch_data():
            pass
    """
    def decorator(func: Callable) -> Callable:
        retries = max(0, max_retries)

        @wraps(func)
        def wrapper(*args, **kwargs):
            delay = retry_delay

            for attempt in range(1, retries + 1):
                try:
                    return func(*args, **kwargs)
                except retryable_exceptions as e:
                    logger.warning(
                        f"Retry {attempt}/{retries} for {func.__name__}: {e}"
                    )
                    time.sleep(delay)
                    delay *= backoff_multiplier

            # Final attempt: nothing is caught, so a failure reaches the
            # caller with its original traceback.
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
def cache_result(
    ttl_seconds: float = 300.0,
    max_size: int = 100
) -> Callable:
    """Decorator to cache function results.

    Simple TTL-based cache with LRU eviction. Results are keyed by the
    positional arguments and the sorted keyword arguments, so all arguments
    must be hashable. An entry is served while it is younger than
    ``ttl_seconds``; its age is counted from the start of the call that
    produced it. When the cache is full the least recently used entry is
    evicted.

    The wrapper gains two helpers: ``clear_cache()`` empties the cache and
    ``cache_info()`` returns a dict with ``size``, ``max_size`` and
    ``ttl_seconds``.

    Args:
        ttl_seconds: Maximum age in seconds of a served entry.
        max_size: Maximum number of entries. With zero or less nothing is
            stored and every call runs the function.

    Returns:
        A decorator producing the caching wrapper.

    Example:
        @cache_result(ttl_seconds=60)
        def get_stock_data(ticker):
            pass
    """
    # Local import keeps this function self-contained.
    from collections import OrderedDict

    def decorator(func: Callable) -> Callable:
        # key -> (value, timestamp); least recently used first. One ordered
        # mapping holds both data and recency, so the two cannot drift apart.
        cache: OrderedDict = OrderedDict()
        cache_lock = threading.Lock()

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Create cache key from args
            key = (args, tuple(sorted(kwargs.items())))
            now = time.time()

            with cache_lock:
                entry = cache.get(key)
                if entry is not None:
                    value, timestamp = entry
                    if now - timestamp < ttl_seconds:
                        # Fresh hit: mark as most recently used.
                        cache.move_to_end(key)
                        return value
                    # Expired
                    del cache[key]

            # Execute function outside the lock so a slow call does not
            # block other callers of the cache.
            result = func(*args, **kwargs)

            if max_size > 0:
                with cache_lock:
                    # Another thread may have stored this key meanwhile; drop
                    # it first so it is neither counted nor duplicated.
                    cache.pop(key, None)

                    # Evict oldest if at capacity
                    while len(cache) >= max_size:
                        cache.popitem(last=False)

                    # Store result as the most recently used entry
                    cache[key] = (result, now)

            return result

        # Add cache control methods
        def clear_cache():
            with cache_lock:
                cache.clear()

        def cache_info():
            with cache_lock:
                return {
                    "size": len(cache),
                    "max_size": max_size,
                    "ttl_seconds": ttl_seconds
                }

        wrapper.clear_cache = clear_cache
        wrapper.cache_info = cache_info

        return wrapper

    return decorator
|