140 lines
3.9 KiB
Python
140 lines
3.9 KiB
Python
import hashlib
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Any
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from tradingagents.database.models import DataCache
|
|
from tradingagents.database.repositories import DataCacheRepository
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MarketDataService:
|
|
def __init__(self, session: Session):
|
|
self.session = session
|
|
self.cache_repo = DataCacheRepository(session)
|
|
|
|
def _generate_cache_key(
|
|
self,
|
|
method: str,
|
|
ticker: str | None = None,
|
|
date: str | None = None,
|
|
**kwargs: Any,
|
|
) -> str:
|
|
key_parts = [method]
|
|
if ticker:
|
|
key_parts.append(ticker)
|
|
if date:
|
|
key_parts.append(date)
|
|
for k, v in sorted(kwargs.items()):
|
|
key_parts.append(f"{k}={v}")
|
|
key_string = ":".join(key_parts)
|
|
return hashlib.sha256(key_string.encode()).hexdigest()[:32]
|
|
|
|
def get_cached_data(
|
|
self,
|
|
method: str,
|
|
ticker: str | None = None,
|
|
date: str | None = None,
|
|
**kwargs: Any,
|
|
) -> Any | None:
|
|
cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
|
|
cache_entry = self.cache_repo.get_valid_cache(cache_key)
|
|
|
|
if cache_entry and cache_entry.cached_data:
|
|
logger.debug("Cache hit for %s (key: %s)", method, cache_key[:8])
|
|
try:
|
|
return json.loads(cache_entry.cached_data)
|
|
except json.JSONDecodeError:
|
|
return cache_entry.cached_data
|
|
|
|
logger.debug("Cache miss for %s (key: %s)", method, cache_key[:8])
|
|
return None
|
|
|
|
def set_cached_data(
|
|
self,
|
|
method: str,
|
|
data: Any,
|
|
ticker: str | None = None,
|
|
date: str | None = None,
|
|
ttl_hours: int = 24,
|
|
**kwargs: Any,
|
|
) -> DataCache:
|
|
cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
|
|
expires_at = datetime.utcnow() + timedelta(hours=ttl_hours)
|
|
|
|
cached_data = data if isinstance(data, str) else json.dumps(data)
|
|
|
|
logger.debug(
|
|
"Caching data for %s (key: %s, ttl: %dh)",
|
|
method,
|
|
cache_key[:8],
|
|
ttl_hours,
|
|
)
|
|
|
|
return self.cache_repo.set_cache(
|
|
cache_key=cache_key,
|
|
data_type=method,
|
|
cached_data=cached_data,
|
|
expires_at=expires_at,
|
|
ticker=ticker,
|
|
)
|
|
|
|
def get_or_fetch(
|
|
self,
|
|
method: str,
|
|
fetch_func: callable,
|
|
ticker: str | None = None,
|
|
date: str | None = None,
|
|
ttl_hours: int = 24,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
cached = self.get_cached_data(method, ticker, date, **kwargs)
|
|
if cached is not None:
|
|
return cached
|
|
|
|
logger.debug("Fetching fresh data for %s", method)
|
|
result = fetch_func()
|
|
|
|
if result:
|
|
self.set_cached_data(method, result, ticker, date, ttl_hours, **kwargs)
|
|
|
|
return result
|
|
|
|
def clear_expired_cache(self) -> int:
|
|
count = self.cache_repo.clear_expired()
|
|
logger.info("Cleared %d expired cache entries", count)
|
|
return count
|
|
|
|
def invalidate_ticker_cache(self, ticker: str) -> int:
|
|
entries = self.session.query(DataCache).filter(DataCache.ticker == ticker).all()
|
|
count = 0
|
|
for entry in entries:
|
|
self.session.delete(entry)
|
|
count += 1
|
|
self.session.flush()
|
|
logger.info("Invalidated %d cache entries for ticker %s", count, ticker)
|
|
return count
|
|
|
|
|
|
DEFAULT_TTL_HOURS = {
|
|
"get_stock_data": 1,
|
|
"get_indicators": 1,
|
|
"get_fundamentals": 24,
|
|
"get_balance_sheet": 168,
|
|
"get_cashflow": 168,
|
|
"get_income_statement": 168,
|
|
"get_news": 1,
|
|
"get_global_news": 1,
|
|
"get_insider_sentiment": 24,
|
|
"get_insider_transactions": 24,
|
|
"get_bulk_news": 1,
|
|
}
|
|
|
|
|
|
def get_default_ttl(method: str) -> int:
|
|
return DEFAULT_TTL_HOURS.get(method, 24)
|