355 lines
12 KiB
Python
355 lines
12 KiB
Python
"""Tests for VendorRegistry.
|
|
|
|
Issue #11: [DATA-10] Interface routing - add new data vendors
|
|
"""
|
|
|
|
import pytest
|
|
import threading
|
|
from unittest.mock import Mock
|
|
|
|
from tradingagents.dataflows.vendor_registry import (
|
|
VendorCapability,
|
|
VendorMetadata,
|
|
VendorRegistry,
|
|
get_registry,
|
|
register_vendor,
|
|
register_method,
|
|
)
|
|
|
|
# Module-level ``pytestmark`` applies the "unit" marker to every test in this
# file, so the whole module can be selected or deselected with ``-m unit``.
pytestmark = pytest.mark.unit
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_registry():
    """Give every test in this module a fresh VendorRegistry singleton.

    ``autouse=True`` makes this run around each test without being requested.
    The singleton is reset before the test body (setup) and again after it
    (teardown, the code following ``yield``).
    """
    VendorRegistry.reset_instance()  # setup: start from a new singleton
    yield
    VendorRegistry.reset_instance()  # teardown: drop what the test registered
|
|
|
|
|
|
class TestVendorMetadata:
    """Behaviour of the VendorMetadata dataclass: defaults, overrides, identity."""

    def test_default_values(self):
        """Only ``name`` is required; every other field has a default."""
        meta = VendorMetadata(name="test")

        assert meta.name == "test"
        assert meta.priority == 100
        assert meta.capabilities == set()
        assert meta.rate_limit_exception is None
        assert meta.description == ""
        assert meta.enabled is True

    def test_custom_values(self):
        """Explicitly supplied field values are stored unchanged."""
        custom = VendorMetadata(
            name="yfinance",
            priority=10,
            capabilities={VendorCapability.STOCK_DATA},
            description="Yahoo Finance vendor",
        )

        assert custom.name == "yfinance"
        assert custom.priority == 10
        assert VendorCapability.STOCK_DATA in custom.capabilities
        assert custom.description == "Yahoo Finance vendor"

    def test_hash(self):
        """Instances sharing a name hash equally, whatever their priority."""
        first = VendorMetadata(name="vendor1")
        second = VendorMetadata(name="vendor1", priority=50)

        assert hash(first) == hash(second)

    def test_equality(self):
        """Equality is decided by ``name`` alone."""
        same_name_a = VendorMetadata(name="vendor1", priority=10)
        same_name_b = VendorMetadata(name="vendor1", priority=50)
        different = VendorMetadata(name="vendor2")

        assert same_name_a == same_name_b
        assert same_name_a != different

    def test_equality_with_non_metadata(self):
        """Objects of other types never compare equal, even with a matching name."""
        meta = VendorMetadata(name="test")

        for foreign in ("test", {"name": "test"}):
            assert meta != foreign
|
|
|
|
|
|
class TestVendorRegistry:
    """Singleton identity plus vendor registration and lookup on VendorRegistry."""

    def test_singleton_pattern(self):
        """Constructing the registry twice yields the same object."""
        first, second = VendorRegistry(), VendorRegistry()

        assert first is second

    def test_get_registry_returns_singleton(self):
        """``get_registry()`` hands back the instance the constructor returns."""
        via_helper = get_registry()
        via_constructor = VendorRegistry()

        assert via_helper is via_constructor

    def test_reset_instance(self):
        """After ``reset_instance()`` the constructor produces a new object."""
        before = VendorRegistry()
        VendorRegistry.reset_instance()
        after = VendorRegistry()

        assert before is not after

    def test_register_vendor(self):
        """A registered vendor can be looked up by its name."""
        registry = VendorRegistry()
        registry.register_vendor(
            VendorMetadata(name="yfinance", capabilities={VendorCapability.STOCK_DATA})
        )

        found = registry.get_vendor("yfinance")

        assert found is not None
        assert found.name == "yfinance"

    def test_register_vendor_empty_name_raises(self):
        """An empty vendor name is rejected with ValueError."""
        registry = VendorRegistry()

        # Metadata is built inside the context so a constructor-side
        # validation error would also be caught here.
        with pytest.raises(ValueError, match="cannot be empty"):
            registry.register_vendor(VendorMetadata(name=""))

    def test_unregister_vendor(self):
        """Unregistering reports success and removes the vendor."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(name="test_vendor"))

        removed = registry.unregister_vendor("test_vendor")

        assert removed is True
        assert registry.get_vendor("test_vendor") is None

    def test_unregister_nonexistent_vendor(self):
        """Unregistering an unknown name reports False."""
        registry = VendorRegistry()

        removed = registry.unregister_vendor("nonexistent")

        assert removed is False

    def test_get_all_vendors_sorted_by_priority(self):
        """``get_all_vendors`` orders by ascending priority number."""
        registry = VendorRegistry()
        for vendor_name, vendor_priority in (("low", 100), ("high", 10), ("medium", 50)):
            registry.register_vendor(
                VendorMetadata(name=vendor_name, priority=vendor_priority)
            )

        ordered_names = [vendor.name for vendor in registry.get_all_vendors()]

        assert ordered_names == ["high", "medium", "low"]

    def test_get_vendors_for_capability(self):
        """Capability lookup returns exactly the vendors declaring it."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(
            name="yfinance",
            capabilities={VendorCapability.STOCK_DATA, VendorCapability.FUNDAMENTALS},
        ))
        registry.register_vendor(VendorMetadata(
            name="alpha_vantage",
            capabilities={VendorCapability.STOCK_DATA},
        ))

        stock_matches = registry.get_vendors_for_capability(VendorCapability.STOCK_DATA)
        assert len(stock_matches) == 2

        fundamentals_matches = registry.get_vendors_for_capability(
            VendorCapability.FUNDAMENTALS
        )
        assert len(fundamentals_matches) == 1
        assert fundamentals_matches[0].name == "yfinance"

    def test_get_vendors_for_capability_only_enabled(self):
        """Disabled vendors are hidden unless ``only_enabled=False`` is passed."""
        registry = VendorRegistry()
        stock = VendorCapability.STOCK_DATA
        registry.register_vendor(VendorMetadata(
            name="enabled_vendor",
            capabilities={stock},
        ))
        registry.register_vendor(VendorMetadata(
            name="disabled_vendor",
            capabilities={stock},
            enabled=False,
        ))

        active = registry.get_vendors_for_capability(stock)
        assert len(active) == 1
        assert active[0].name == "enabled_vendor"

        everything = registry.get_vendors_for_capability(stock, only_enabled=False)
        assert len(everything) == 2
|
|
|
|
|
|
class TestVendorRegistryMethods:
    """Registering and retrieving per-vendor callables on VendorRegistry."""

    def test_register_method(self):
        """A registered callable is returned unchanged by ``get_method``."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(name="yfinance"))
        handler = Mock()

        registry.register_method("yfinance", "get_stock", handler)

        assert registry.get_method("yfinance", "get_stock") is handler

    def test_register_method_unregistered_vendor_raises(self):
        """Methods can only be attached to vendors that are already registered."""
        registry = VendorRegistry()

        with pytest.raises(ValueError, match="not registered"):
            registry.register_method("nonexistent", "method", Mock())

    def test_register_method_empty_name_raises(self):
        """An empty method name is rejected with ValueError."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(name="test"))

        with pytest.raises(ValueError, match="cannot be empty"):
            registry.register_method("test", "", Mock())

    def test_get_method_nonexistent(self):
        """Looking up an unknown method on a known vendor yields None."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(name="test"))

        missing = registry.get_method("test", "nonexistent")

        assert missing is None

    def test_get_methods_for_vendor(self):
        """All callables registered for a vendor are returned, keyed by name."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(name="yfinance"))
        expected = {"get_stock": Mock(), "get_fundamentals": Mock()}
        for method_name, handler in expected.items():
            registry.register_method("yfinance", method_name, handler)

        methods = registry.get_methods_for_vendor("yfinance")

        assert len(methods) == 2
        for method_name, handler in expected.items():
            assert methods[method_name] is handler

    def test_get_methods_for_unregistered_vendor(self):
        """An unknown vendor has an empty method mapping."""
        registry = VendorRegistry()

        assert registry.get_methods_for_vendor("nonexistent") == {}
|
|
|
|
|
|
class TestVendorRegistryVendorControl:
    """Tests for vendor enable/disable, priority control and ``clear``."""

    def test_set_vendor_enabled(self):
        """Toggling ``enabled`` reports success and is reflected on the vendor."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(name="test", enabled=True))

        assert registry.set_vendor_enabled("test", False) is True
        assert registry.get_vendor("test").enabled is False

        assert registry.set_vendor_enabled("test", True) is True
        assert registry.get_vendor("test").enabled is True

    def test_set_vendor_enabled_nonexistent(self):
        """Toggling an unknown vendor reports False."""
        registry = VendorRegistry()

        assert registry.set_vendor_enabled("nonexistent", True) is False

    def test_set_vendor_priority(self):
        """Updating the priority reports success and stores the new value."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(name="test", priority=100))

        assert registry.set_vendor_priority("test", 10) is True
        assert registry.get_vendor("test").priority == 10

    def test_set_vendor_priority_nonexistent(self):
        """Updating the priority of an unknown vendor reports False."""
        registry = VendorRegistry()

        assert registry.set_vendor_priority("nonexistent", 10) is False

    def test_clear(self):
        """Test clearing all registrations.

        The vendor is registered with a capability and a method, so ``clear``
        is checked against every kind of registration, not only the vendor
        table.
        """
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(
            name="test",
            capabilities={VendorCapability.STOCK_DATA},
        ))
        registry.register_method("test", "method", Mock())

        registry.clear()

        assert registry.get_vendor("test") is None
        assert len(registry.get_all_vendors()) == 0
        # Without these checks a clear() that forgot the method table or the
        # capability lookup would still pass.
        assert registry.get_methods_for_vendor("test") == {}
        leftover = registry.get_vendors_for_capability(
            VendorCapability.STOCK_DATA,
            only_enabled=False,
        )
        assert len(leftover) == 0
|
|
|
|
|
|
class TestVendorRegistryThreadSafety:
    """Tests for thread safety of VendorRegistry."""

    # Number of worker threads each test launches.
    THREAD_COUNT = 50
    # Seconds to wait at the start barrier and when joining each worker.
    # Bounds the test so a deadlocked registry fails instead of hanging.
    TIMEOUT_SECONDS = 10.0

    def _run_concurrently(self, work):
        """Call ``work(i)`` for i in range(THREAD_COUNT), one call per thread.

        A barrier makes every worker wait until all THREAD_COUNT have started.
        The registry calls therefore overlap instead of running one after
        another as threads happen to be scheduled.

        Returns the list of exceptions raised inside workers, including
        ``threading.BrokenBarrierError`` if the barrier times out. Fails the
        test if any worker is still alive after the join timeout.
        """
        start_gate = threading.Barrier(self.THREAD_COUNT)
        errors = []

        def worker(i):
            try:
                start_gate.wait(timeout=self.TIMEOUT_SECONDS)
                work(i)
            except Exception as exc:  # surfaced to the caller via ``errors``
                errors.append(exc)

        # daemon=True so a stuck worker cannot keep the interpreter alive.
        threads = [
            threading.Thread(target=worker, args=(i,), daemon=True)
            for i in range(self.THREAD_COUNT)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join(timeout=self.TIMEOUT_SECONDS)

        stuck = [thread.name for thread in threads if thread.is_alive()]
        assert not stuck, f"worker threads did not finish: {stuck}"
        return errors

    def test_concurrent_registration(self):
        """Test concurrent vendor registration is thread-safe."""
        registry = VendorRegistry()

        # Named so it does not shadow the module-level ``register_vendor``.
        def add_vendor(i):
            registry.register_vendor(VendorMetadata(
                name=f"vendor_{i}",
                capabilities={VendorCapability.STOCK_DATA},
            ))

        errors = self._run_concurrently(add_vendor)

        assert errors == []
        expected_names = {f"vendor_{i}" for i in range(self.THREAD_COUNT)}
        assert {v.name for v in registry.get_all_vendors()} == expected_names
        # The capability lookup must stay consistent with the vendor table.
        stock_vendors = registry.get_vendors_for_capability(VendorCapability.STOCK_DATA)
        assert len(stock_vendors) == self.THREAD_COUNT

    def test_concurrent_method_registration(self):
        """Test concurrent method registration is thread-safe."""
        registry = VendorRegistry()
        registry.register_vendor(VendorMetadata(name="test"))

        # Named so it does not shadow the module-level ``register_method``.
        def add_method(i):
            registry.register_method("test", f"method_{i}", Mock())

        errors = self._run_concurrently(add_method)

        assert errors == []
        methods = registry.get_methods_for_vendor("test")
        assert set(methods) == {f"method_{i}" for i in range(self.THREAD_COUNT)}
|
|
|
|
|
|
class TestModuleFunctions:
    """The module-level helpers must act on the shared singleton registry."""

    def test_register_vendor_function(self):
        """``register_vendor`` makes the vendor visible through ``get_registry()``."""
        register_vendor(VendorMetadata(name="module_test"))

        shared = get_registry()
        assert shared.get_vendor("module_test") is not None

    def test_register_method_function(self):
        """``register_method`` attaches the callable on the shared registry."""
        register_vendor(VendorMetadata(name="method_test"))
        handler = Mock()

        register_method("method_test", "test_method", handler)

        shared = get_registry()
        assert shared.get_method("method_test", "test_method") is handler
|