feat: add database-backed caching to dataflows interface
- Add MarketDataService for database-backed market data caching - Integrate cache lookup/write into route_to_vendor function - Support configurable TTL per data type (stock data: 1h, fundamentals: 24h+) - Make caching opt-in via database_enabled and db_cache_enabled config flags 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
1db81e1fc6
commit
fb1a66f5a6
|
|
@ -1,6 +1,12 @@
|
|||
from .base import Base
|
||||
from .engine import get_db_session, get_engine, init_database, reset_engine
|
||||
from .services import AnalysisService, DiscoveryService, TradingService
|
||||
from .services import (
|
||||
AnalysisService,
|
||||
DiscoveryService,
|
||||
MarketDataService,
|
||||
TradingService,
|
||||
get_default_ttl,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
|
|
@ -10,5 +16,7 @@ __all__ = [
|
|||
"reset_engine",
|
||||
"AnalysisService",
|
||||
"DiscoveryService",
|
||||
"MarketDataService",
|
||||
"TradingService",
|
||||
"get_default_ttl",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
from .analysis import AnalysisService
|
||||
from .discovery import DiscoveryService
|
||||
from .market_data import MarketDataService, get_default_ttl
|
||||
from .trading import TradingService
|
||||
|
||||
__all__ = [
|
||||
"AnalysisService",
|
||||
"DiscoveryService",
|
||||
"MarketDataService",
|
||||
"TradingService",
|
||||
"get_default_ttl",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,139 @@
|
|||
import hashlib
import json
import logging
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any

from sqlalchemy.orm import Session

from tradingagents.database.models import DataCache
from tradingagents.database.repositories import DataCacheRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketDataService:
    """Database-backed cache for market-data vendor responses.

    Payloads are stored through ``DataCacheRepository`` as JSON (or verbatim
    strings), keyed by a SHA-256 digest of the request parameters, each with
    a per-entry TTL.
    """

    def __init__(self, session: Session):
        """Bind the service to an open SQLAlchemy session."""
        self.session = session
        self.cache_repo = DataCacheRepository(session)

    def _generate_cache_key(
        self,
        method: str,
        ticker: str | None = None,
        date: str | None = None,
        **kwargs: Any,
    ) -> str:
        """Return a deterministic cache key for (method, ticker, date, kwargs).

        Extra kwargs are sorted so that call-site argument order never
        changes the resulting key.
        """
        key_parts = [method]
        if ticker:
            key_parts.append(ticker)
        if date:
            key_parts.append(date)
        for k, v in sorted(kwargs.items()):
            key_parts.append(f"{k}={v}")
        key_string = ":".join(key_parts)
        # 32 hex chars (128 bits of the digest) is ample to avoid collisions.
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]

    def get_cached_data(
        self,
        method: str,
        ticker: str | None = None,
        date: str | None = None,
        **kwargs: Any,
    ) -> Any | None:
        """Return the cached value for this request, or ``None`` on a miss.

        JSON-stored values are decoded; anything that fails to decode is
        returned as the raw stored string (see :meth:`set_cached_data`).
        """
        cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
        cache_entry = self.cache_repo.get_valid_cache(cache_key)

        if cache_entry and cache_entry.cached_data:
            logger.debug("Cache hit for %s (key: %s)", method, cache_key[:8])
            try:
                return json.loads(cache_entry.cached_data)
            except json.JSONDecodeError:
                # Entry was stored as a plain (non-JSON) string.
                return cache_entry.cached_data

        logger.debug("Cache miss for %s (key: %s)", method, cache_key[:8])
        return None

    def set_cached_data(
        self,
        method: str,
        data: Any,
        ticker: str | None = None,
        date: str | None = None,
        ttl_hours: int = 24,
        **kwargs: Any,
    ) -> DataCache:
        """Store *data* under the derived cache key with the given TTL.

        Strings are stored verbatim; everything else is JSON-encoded.
        Returns the persisted ``DataCache`` row.
        """
        cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
        # NOTE(review): utcnow() is naive (and deprecated since 3.12); this
        # assumes the repository compares against naive UTC timestamps —
        # confirm before switching to datetime.now(timezone.utc).
        expires_at = datetime.utcnow() + timedelta(hours=ttl_hours)

        cached_data = data if isinstance(data, str) else json.dumps(data)

        logger.debug(
            "Caching data for %s (key: %s, ttl: %dh)",
            method,
            cache_key[:8],
            ttl_hours,
        )

        return self.cache_repo.set_cache(
            cache_key=cache_key,
            data_type=method,
            cached_data=cached_data,
            expires_at=expires_at,
            ticker=ticker,
        )

    def get_or_fetch(
        self,
        method: str,
        fetch_func: Callable[[], Any],
        ticker: str | None = None,
        date: str | None = None,
        ttl_hours: int = 24,
        **kwargs: Any,
    ) -> Any:
        """Return cached data, or call *fetch_func* and cache its result.

        Falsy results (``None``, ``""``, ``[]`` …) are returned but not
        cached — a falsy stored payload would read back as a miss anyway.
        """
        cached = self.get_cached_data(method, ticker, date, **kwargs)
        if cached is not None:
            return cached

        logger.debug("Fetching fresh data for %s", method)
        result = fetch_func()

        if result:
            self.set_cached_data(method, result, ticker, date, ttl_hours, **kwargs)

        return result

    def clear_expired_cache(self) -> int:
        """Delete expired cache rows and return how many were removed."""
        count = self.cache_repo.clear_expired()
        logger.info("Cleared %d expired cache entries", count)
        return count

    def invalidate_ticker_cache(self, ticker: str) -> int:
        """Delete every cache row for *ticker*; return the number removed.

        Uses a single bulk DELETE instead of loading each row into the
        session and deleting one at a time.
        """
        count = (
            self.session.query(DataCache)
            .filter(DataCache.ticker == ticker)
            .delete(synchronize_session=False)
        )
        self.session.flush()
        logger.info("Invalidated %d cache entries for ticker %s", count, ticker)
        return count
|
||||
|
||||
|
||||
# Per-method cache lifetimes in hours: fast-moving data (prices, news) gets
# 1h; fundamentals 24h; slow-moving financial statements a week (168h).
DEFAULT_TTL_HOURS = {
    "get_stock_data": 1,
    "get_indicators": 1,
    "get_fundamentals": 24,
    "get_balance_sheet": 168,
    "get_cashflow": 168,
    "get_income_statement": 168,
    "get_news": 1,
    "get_global_news": 1,
    "get_insider_sentiment": 24,
    "get_insider_transactions": 24,
    "get_bulk_news": 1,
}


def get_default_ttl(method: str) -> int:
    """Return the default cache TTL in hours for *method*.

    Unknown methods fall back to 24 hours.
    """
    try:
        return DEFAULT_TTL_HOURS[method]
    except KeyError:
        return 24
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.agents.discovery import NewsArticle
|
||||
|
||||
|
|
@ -139,6 +139,52 @@ _bulk_news_cache: dict[str, dict[str, Any]] = {}
|
|||
_bulk_news_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def _is_db_cache_enabled() -> bool:
    """True when both the database and its cache layer are enabled in config.

    ``database_enabled`` defaults off; ``db_cache_enabled`` defaults on, so
    enabling the database alone turns caching on.
    """
    cfg = get_config()
    db_on = cfg.get("database_enabled", False)
    cache_on = cfg.get("db_cache_enabled", True)
    return db_on and cache_on
|
||||
|
||||
|
||||
def _get_db_cached_data(method: str, ticker: str | None = None, **kwargs: Any) -> Any:
    """Look up *method*'s result in the DB cache; ``None`` on miss or failure.

    Best-effort: import/connection problems are logged at debug level and
    treated as a cache miss so vendor routing proceeds normally.
    """
    if not _is_db_cache_enabled():
        return None

    try:
        from tradingagents.database import MarketDataService, get_db_session

        with get_db_session() as session:
            return MarketDataService(session).get_cached_data(method, ticker, **kwargs)
    except (ImportError, RuntimeError, ConnectionError) as exc:
        logger.debug("DB cache lookup failed: %s", exc)
        return None
|
||||
|
||||
|
||||
def _set_db_cached_data(
    method: str,
    data: Any,
    ticker: str | None = None,
    **kwargs: Any,
) -> None:
    """Write *data* to the DB cache using the method's default TTL.

    Best-effort: import/connection problems are logged at debug level and
    the write is silently skipped.
    """
    if not _is_db_cache_enabled():
        return

    try:
        from tradingagents.database import (
            MarketDataService,
            get_db_session,
            get_default_ttl,
        )

        ttl = get_default_ttl(method)
        with get_db_session() as session:
            MarketDataService(session).set_cached_data(
                method, data, ticker, ttl_hours=ttl, **kwargs
            )
    except (ImportError, RuntimeError, ConnectionError) as exc:
        logger.debug("DB cache write failed: %s", exc)
|
||||
|
||||
|
||||
def parse_lookback_period(lookback: str) -> int:
|
||||
lookback = lookback.lower().strip()
|
||||
|
||||
|
|
@ -277,6 +323,14 @@ def get_vendor(category: str, method: str = None) -> str:
|
|||
|
||||
|
||||
def route_to_vendor(method: str, *args, **kwargs):
|
||||
ticker = kwargs.get("ticker") or (args[0] if args else None)
|
||||
date = kwargs.get("date") or (args[1] if len(args) > 1 else None)
|
||||
|
||||
cached_result = _get_db_cached_data(method, ticker, date=date)
|
||||
if cached_result is not None:
|
||||
logger.info("DB cache hit for %s (ticker=%s)", method, ticker)
|
||||
return cached_result
|
||||
|
||||
category = get_category_for_method(method)
|
||||
vendor_config = get_vendor(category, method)
|
||||
|
||||
|
|
@ -407,6 +461,10 @@ def route_to_vendor(method: str, *args, **kwargs):
|
|||
)
|
||||
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
final_result = results[0]
|
||||
else:
|
||||
return "\n".join(str(result) for result in results)
|
||||
final_result = "\n".join(str(result) for result in results)
|
||||
|
||||
_set_db_cached_data(method, final_result, ticker, date=date)
|
||||
|
||||
return final_result
|
||||
|
|
|
|||
Loading…
Reference in New Issue