TradingAgents/tradingagents/dataflows/search_provider_factory.py

133 lines
4.9 KiB
Python

from .search_provider import SearchProvider
import hashlib
import json
from typing import Dict, Callable, Any
from abc import ABC, abstractmethod
class ProviderSelector(ABC):
    """Strategy interface that maps a configuration to a provider-type name.

    Concrete selectors implement :meth:`select_provider_type`. The factory in
    this module uses the returned string as the key for looking up a creator
    in its provider registry.
    """

    @abstractmethod
    def select_provider_type(self, config: Dict[str, Any]) -> str:
        """Return the provider-type identifier to use for ``config``.

        Args:
            config: Application configuration mapping.

        Returns:
            A registered provider-type name (for example ``"openai"``).
        """
class MappingBasedProviderSelector(ProviderSelector):
    """Select a provider type by substring-matching the configured backend URL.

    Each key of ``mappings`` is a substring pattern. The first pattern (in the
    mapping's iteration order) that occurs inside ``config["backend_url"]``
    determines the provider type. If no pattern matches, ``default_provider``
    is returned.
    """

    def __init__(self, mappings: Dict[str, str], default_provider: str = "openai"):
        """
        Args:
            mappings: URL substring -> provider-type name.
            default_provider: Provider type returned when no pattern matches.
        """
        self._mappings = mappings
        self._default_provider = default_provider

    def select_provider_type(self, config: Dict[str, Any]) -> str:
        """Return the provider type whose pattern occurs in ``backend_url``.

        Args:
            config: Configuration mapping; only ``"backend_url"`` is read.

        Returns:
            The mapped provider type, or the default provider when nothing matches.
        """
        # `config.get("backend_url", "")` still yields None when the key is
        # present but explicitly None, and `pattern in None` raises TypeError.
        # Treat a missing or empty URL as "" so the default provider is used.
        backend_url = config.get("backend_url") or ""
        for pattern, provider_type in self._mappings.items():
            if pattern in backend_url:
                return provider_type
        return self._default_provider
class SearchProviderRegistry:
    """Registry that maps provider-type names to ``SearchProvider`` creator callables."""

    def __init__(self):
        # provider-type name -> callable that builds a provider from a config dict
        self._providers: Dict[str, Callable[[Dict[str, Any]], SearchProvider]] = {}

    def register(self, provider_type: str, creator: Callable[[Dict[str, Any]], SearchProvider]):
        """Associate ``provider_type`` with ``creator``.

        Registering a name that is already present replaces its previous creator.
        """
        self._providers[provider_type] = creator

    def create(self, provider_type: str, config: Dict[str, Any]) -> SearchProvider:
        """Build a provider by calling the creator registered for ``provider_type``.

        Args:
            provider_type: Name previously passed to :meth:`register`.
            config: Configuration mapping forwarded unchanged to the creator.

        Raises:
            ValueError: If ``provider_type`` has not been registered.
        """
        if provider_type in self._providers:
            return self._providers[provider_type](config)
        raise ValueError(f"Unknown provider type: {provider_type}")

    def get_available_types(self) -> list[str]:
        """Return the registered provider-type names, in the order each was first registered."""
        return list(self._providers)
class SearchProviderFactoryImpl:
    """Factory that selects, builds, and caches ``SearchProvider`` instances.

    A :class:`ProviderSelector` chooses the provider type, and a
    :class:`SearchProviderRegistry` builds the instance. Built instances are
    cached, so repeated calls with an equivalent configuration return the
    same object.
    """

    def __init__(self, registry: SearchProviderRegistry, selector: ProviderSelector):
        """
        Args:
            registry: Source of creator callables keyed by provider type.
            selector: Strategy that picks a provider type from a config.
        """
        self._registry = registry
        self._selector = selector
        # cache key (from _cache_key) -> previously built provider instance
        self._cache: Dict[str, SearchProvider] = {}

    @staticmethod
    def _cache_key(provider_type: str, config: Dict[str, Any]) -> str:
        """Build a deterministic cache key from the inputs that shape a provider."""
        # The selector receives the whole config, so it may use more than
        # backend_url. Including the chosen provider type in the key means two
        # configs that select different provider types never share an instance.
        key_data = {
            "provider_type": provider_type,
            "backend_url": config.get("backend_url", ""),
            "model": config.get("quick_think_llm", ""),
        }
        # The canonical JSON text is already a valid dict key. Hashing it with
        # MD5 adds nothing, and hashlib.md5 can be unavailable on FIPS-restricted
        # OpenSSL builds.
        return json.dumps(key_data, sort_keys=True)

    def create_provider(self, config: Dict[str, Any]) -> SearchProvider:
        """Return a provider for ``config``, reusing a cached instance when possible.

        Args:
            config: Configuration mapping. The selector may read any key. The
                cache key uses the selected provider type plus ``"backend_url"``
                and ``"quick_think_llm"``.

        Returns:
            A cached or newly built ``SearchProvider``.

        Raises:
            ValueError: Propagated from the registry when the selected provider
                type is not registered.
        """
        provider_type = self._selector.select_provider_type(config)
        cache_key = self._cache_key(provider_type, config)

        if cache_key in self._cache:
            return self._cache[cache_key]

        provider = self._registry.create(provider_type, config)
        self._cache[cache_key] = provider
        return provider

    def clear_cache(self):
        """Drop all cached providers (useful for testing or after config changes)."""
        self._cache.clear()

    def get_available_provider_types(self) -> list[str]:
        """Return the provider-type names known to the underlying registry."""
        return self._registry.get_available_types()
def create_search_provider_factory() -> SearchProviderFactoryImpl:
    """Build a factory pre-wired with the built-in Google and OpenAI providers.

    A backend URL containing ``generativelanguage.googleapis.com`` routes to the
    ``"google"`` provider, and one containing ``api.openai.com`` routes to
    ``"openai"``. Any other URL falls back to ``"openai"``.

    Returns:
        A ready-to-use :class:`SearchProviderFactoryImpl`. No provider is
        instantiated here; creators run only when a provider is requested.
    """

    def _build_google(cfg: Dict[str, Any]) -> SearchProvider:
        # Deferred import: the provider class is loaded only when this creator runs.
        from .search_provider import GoogleSearchProvider

        return GoogleSearchProvider(cfg["quick_think_llm"])

    def _build_openai(cfg: Dict[str, Any]) -> SearchProvider:
        # Deferred import: the provider class is loaded only when this creator runs.
        from .search_provider import OpenAISearchProvider

        return OpenAISearchProvider(cfg["quick_think_llm"], cfg["backend_url"])

    registry = SearchProviderRegistry()
    for type_name, builder in (("google", _build_google), ("openai", _build_openai)):
        registry.register(type_name, builder)

    # backend_url substring -> provider type; add an entry to route a new backend.
    selector = MappingBasedProviderSelector(
        {
            "generativelanguage.googleapis.com": "google",
            "api.openai.com": "openai",
        },
        default_provider="openai",
    )
    return SearchProviderFactoryImpl(registry, selector)
# Backward compatibility - singleton instance.
# Built once at module import time. The SearchProviderFactory wrapper below
# forwards to this object, so all of its callers share one provider cache.
# Building it only registers creators; no SearchProvider is instantiated yet.
_default_factory: SearchProviderFactoryImpl = create_search_provider_factory()
class SearchProviderFactory:
    """Static facade kept for backward compatibility with the old factory API.

    It keeps the static ``create_provider`` / ``clear_cache`` entry points.
    Both forward to the module-level ``_default_factory`` instance.
    """

    @staticmethod
    def create_provider(config: Dict[str, Any]) -> SearchProvider:
        """Return a possibly cached provider for ``config`` from the shared default factory."""
        shared = _default_factory
        return shared.create_provider(config)

    @staticmethod
    def clear_cache():
        """Discard every provider cached by the shared default factory."""
        shared = _default_factory
        shared.clear_cache()