TradingAgents/tradingagents/dataflows/market_registry.py

187 lines
6.1 KiB
Python

"""
Market provider registry for extensible multi-market support.
Each market (VN, TH, JP, etc.) implements MarketProvider and registers
itself via the MarketRegistry. The routing layer in interface.py checks
the registry before falling back to the default US vendors.
"""
from abc import ABC, abstractmethod
import json
import os
import time
from typing import Optional
class MarketProvider(ABC):
"""Base class for market-specific data providers."""
market_code: str = ""
market_name: str = ""
currency: str = ""
@abstractmethod
def get_listed_tickers(self) -> set:
"""Return all valid ticker symbols for this market."""
...
@abstractmethod
def get_stock_data(self, symbol: str, start_date: str, end_date: str) -> str:
...
@abstractmethod
def get_indicators(self, symbol: str, indicator: str, curr_date: str, look_back_days: int) -> str:
...
@abstractmethod
def get_fundamentals(self, ticker: str, curr_date: str = None) -> str:
...
@abstractmethod
def get_balance_sheet(self, ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
...
@abstractmethod
def get_cashflow(self, ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
...
@abstractmethod
def get_income_statement(self, ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
...
@abstractmethod
def get_insider_transactions(self, ticker: str) -> str:
...
@abstractmethod
def get_news(self, ticker: str, start_date: str, end_date: str) -> str:
...
@abstractmethod
def get_global_news(self, curr_date: str, look_back_days: int = 7, limit: int = 10) -> str:
...
def get_market_context(self) -> str:
"""Return market-specific context string for agent prompts."""
return ""
# Method name -> MarketProvider method mapping
MARKET_METHOD_MAP = {
"get_stock_data": "get_stock_data",
"get_indicators": "get_indicators",
"get_fundamentals": "get_fundamentals",
"get_balance_sheet": "get_balance_sheet",
"get_cashflow": "get_cashflow",
"get_income_statement": "get_income_statement",
"get_insider_transactions": "get_insider_transactions",
"get_news": "get_news",
"get_global_news": "get_global_news",
}
class MarketRegistry:
"""Registry of all available market providers with ticker auto-detection."""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._providers = {}
cls._instance._ticker_to_market = {}
cls._instance._initialized = False
return cls._instance
def register(self, provider: MarketProvider):
"""Register a market provider."""
self._providers[provider.market_code] = provider
def _ensure_tickers_loaded(self):
"""Lazy-load ticker lists from all registered providers."""
if self._initialized:
return
from .config import get_config
config = get_config()
cache_dir = config.get("data_cache_dir", "data")
os.makedirs(cache_dir, exist_ok=True)
for code, provider in self._providers.items():
cache_file = os.path.join(cache_dir, f"market_tickers_{code}.json")
tickers = None
# Try loading from cache (24h TTL)
if os.path.exists(cache_file):
try:
with open(cache_file, "r") as f:
cached = json.load(f)
if time.time() - cached.get("timestamp", 0) < 86400:
tickers = set(cached["tickers"])
except (json.JSONDecodeError, KeyError):
pass
# Fetch fresh if cache miss or expired
if tickers is None:
try:
tickers = provider.get_listed_tickers()
with open(cache_file, "w") as f:
json.dump({
"timestamp": time.time(),
"market": code,
"tickers": list(tickers),
}, f)
except Exception as e:
print(f"Warning: Could not load tickers for market {code}: {e}")
tickers = set()
for ticker in tickers:
self._ticker_to_market[ticker.upper()] = code
self._initialized = True
def detect_market(self, ticker: str) -> Optional[str]:
"""Detect which market a ticker belongs to. Returns market code or None."""
if not ticker:
return None
# Check config for explicit market override
from .config import get_config
config = get_config()
market = config.get("market", "auto")
if market not in ("auto", "US") and market in self._providers:
return market
self._ensure_tickers_loaded()
return self._ticker_to_market.get(ticker.upper())
def get_provider(self, market_code: str) -> Optional[MarketProvider]:
"""Get a market provider by its code."""
return self._providers.get(market_code)
def get_all_providers(self) -> dict:
"""Return all registered providers."""
return dict(self._providers)
def call_market_method(self, market_code: str, method: str, *args, **kwargs):
"""Call a method on the appropriate market provider."""
provider = self.get_provider(market_code)
if provider is None:
raise ValueError(f"No provider registered for market '{market_code}'")
provider_method_name = MARKET_METHOD_MAP.get(method)
if provider_method_name is None:
raise ValueError(f"Method '{method}' not mapped to any market provider method")
provider_method = getattr(provider, provider_method_name)
return provider_method(*args, **kwargs)
@classmethod
def reset(cls):
"""Reset the singleton (useful for testing)."""
cls._instance = None
# Global registry instance
registry = MarketRegistry()