feat: add database-backed caching to dataflows interface

- Add MarketDataService for database-backed market data caching
- Integrate cache lookup/write into route_to_vendor function
- Support configurable TTL per data type (stock data: 1h, fundamentals: 24h+)
- Make caching opt-in via database_enabled and db_cache_enabled config flags

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Joseph O'Brien 2025-12-03 11:45:46 -05:00
parent 1db81e1fc6
commit fb1a66f5a6
4 changed files with 212 additions and 4 deletions
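
The caching is gated by the two config flags named in the commit message. As a usage sketch (the flag names come from this commit; get_config/set_config are assumed to be the project's dataflows config helpers):

# Sketch: opting in to DB-backed caching. Both flags below are read by
# _is_db_cache_enabled() in this commit; the helper import path is an assumption.
from tradingagents.dataflows.config import get_config, set_config

config = get_config()
config["database_enabled"] = True   # master switch: without the DB, no cache
config["db_cache_enabled"] = True   # defaults to True once the DB is enabled
set_config(config)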

tradingagents/database/__init__.py

@@ -1,6 +1,12 @@
 from .base import Base
 from .engine import get_db_session, get_engine, init_database, reset_engine
-from .services import AnalysisService, DiscoveryService, TradingService
+from .services import (
+    AnalysisService,
+    DiscoveryService,
+    MarketDataService,
+    TradingService,
+    get_default_ttl,
+)

 __all__ = [
     "Base",
@@ -10,5 +16,7 @@ __all__ = [
     "reset_engine",
     "AnalysisService",
     "DiscoveryService",
+    "MarketDataService",
     "TradingService",
+    "get_default_ttl",
 ]

tradingagents/database/services/__init__.py

@@ -1,9 +1,12 @@
 from .analysis import AnalysisService
 from .discovery import DiscoveryService
+from .market_data import MarketDataService, get_default_ttl
 from .trading import TradingService

 __all__ = [
     "AnalysisService",
     "DiscoveryService",
+    "MarketDataService",
     "TradingService",
+    "get_default_ttl",
 ]
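
Both __init__ updates re-export the same objects, so either import path works; a quick sanity check (a sketch, assuming the package and its database extras are importable):

# The package-level and services-level exports point at the same class.
from tradingagents.database import MarketDataService as from_package
from tradingagents.database.services import MarketDataService as from_services

assert from_package is from_services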

tradingagents/database/services/market_data.py

@@ -0,0 +1,139 @@
+import hashlib
+import json
+import logging
+from datetime import datetime, timedelta
+from typing import Any
+
+from sqlalchemy.orm import Session
+
+from tradingagents.database.models import DataCache
+from tradingagents.database.repositories import DataCacheRepository
+
+logger = logging.getLogger(__name__)
+
+
+class MarketDataService:
+    def __init__(self, session: Session):
+        self.session = session
+        self.cache_repo = DataCacheRepository(session)
+
+    def _generate_cache_key(
+        self,
+        method: str,
+        ticker: str | None = None,
+        date: str | None = None,
+        **kwargs: Any,
+    ) -> str:
+        key_parts = [method]
+        if ticker:
+            key_parts.append(ticker)
+        if date:
+            key_parts.append(date)
+        for k, v in sorted(kwargs.items()):
+            key_parts.append(f"{k}={v}")
+        key_string = ":".join(key_parts)
+        return hashlib.sha256(key_string.encode()).hexdigest()[:32]
+
+    def get_cached_data(
+        self,
+        method: str,
+        ticker: str | None = None,
+        date: str | None = None,
+        **kwargs: Any,
+    ) -> Any | None:
+        cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
+        cache_entry = self.cache_repo.get_valid_cache(cache_key)
+        if cache_entry and cache_entry.cached_data:
+            logger.debug("Cache hit for %s (key: %s)", method, cache_key[:8])
+            try:
+                return json.loads(cache_entry.cached_data)
+            except json.JSONDecodeError:
+                return cache_entry.cached_data
+        logger.debug("Cache miss for %s (key: %s)", method, cache_key[:8])
+        return None
+
+    def set_cached_data(
+        self,
+        method: str,
+        data: Any,
+        ticker: str | None = None,
+        date: str | None = None,
+        ttl_hours: int = 24,
+        **kwargs: Any,
+    ) -> DataCache:
+        cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
+        expires_at = datetime.utcnow() + timedelta(hours=ttl_hours)
+        cached_data = data if isinstance(data, str) else json.dumps(data)
+        logger.debug(
+            "Caching data for %s (key: %s, ttl: %dh)",
+            method,
+            cache_key[:8],
+            ttl_hours,
+        )
+        return self.cache_repo.set_cache(
+            cache_key=cache_key,
+            data_type=method,
+            cached_data=cached_data,
+            expires_at=expires_at,
+            ticker=ticker,
+        )
+
+    def get_or_fetch(
+        self,
+        method: str,
+        fetch_func: callable,
+        ticker: str | None = None,
+        date: str | None = None,
+        ttl_hours: int = 24,
+        **kwargs: Any,
+    ) -> Any:
+        cached = self.get_cached_data(method, ticker, date, **kwargs)
+        if cached is not None:
+            return cached
+
+        logger.debug("Fetching fresh data for %s", method)
+        result = fetch_func()
+        if result:
+            self.set_cached_data(method, result, ticker, date, ttl_hours, **kwargs)
+        return result
+
+    def clear_expired_cache(self) -> int:
+        count = self.cache_repo.clear_expired()
+        logger.info("Cleared %d expired cache entries", count)
+        return count
+
+    def invalidate_ticker_cache(self, ticker: str) -> int:
+        entries = self.session.query(DataCache).filter(DataCache.ticker == ticker).all()
+        count = 0
+        for entry in entries:
+            self.session.delete(entry)
+            count += 1
+        self.session.flush()
+        logger.info("Invalidated %d cache entries for ticker %s", count, ticker)
+        return count
+
+
+DEFAULT_TTL_HOURS = {
+    "get_stock_data": 1,
+    "get_indicators": 1,
+    "get_fundamentals": 24,
+    "get_balance_sheet": 168,
+    "get_cashflow": 168,
+    "get_income_statement": 168,
+    "get_news": 1,
+    "get_global_news": 1,
+    "get_insider_sentiment": 24,
+    "get_insider_transactions": 24,
+    "get_bulk_news": 1,
+}
+
+
+def get_default_ttl(method: str) -> int:
+    return DEFAULT_TTL_HOURS.get(method, 24)
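
A minimal read-through sketch for the service above, assuming the schema already exists (the fetcher, ticker, and date are illustrative):

# Sketch: read-through caching with MarketDataService.
from tradingagents.database import MarketDataService, get_db_session, get_default_ttl

def fetch_from_vendor():
    # Hypothetical stand-in for a real vendor API call.
    return {"ticker": "AAPL", "close": 192.3}

with get_db_session() as session:
    service = MarketDataService(session)
    data = service.get_or_fetch(
        "get_stock_data",
        fetch_from_vendor,
        ticker="AAPL",
        date="2025-12-03",
        ttl_hours=get_default_ttl("get_stock_data"),  # 1 hour for price data
    )
    # A repeat call with the same arguments within the hour returns the
    # cached JSON without invoking fetch_from_vendor again.

One quirk worth noting: get_or_fetch skips the cache write when fetch_func returns a falsy value, so empty results are re-fetched on every call.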

tradingagents/dataflows/interface.py

@@ -1,7 +1,7 @@
 import logging
 import threading
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any

 from tradingagents.agents.discovery import NewsArticle
@@ -139,6 +139,52 @@ _bulk_news_cache: dict[str, dict[str, Any]] = {}
 _bulk_news_cache_lock = threading.Lock()


+def _is_db_cache_enabled() -> bool:
+    config = get_config()
+    return config.get("database_enabled", False) and config.get(
+        "db_cache_enabled", True
+    )
+
+
+def _get_db_cached_data(method: str, ticker: str | None = None, **kwargs: Any) -> Any:
+    if not _is_db_cache_enabled():
+        return None
+
+    try:
+        from tradingagents.database import MarketDataService, get_db_session
+
+        with get_db_session() as session:
+            service = MarketDataService(session)
+            return service.get_cached_data(method, ticker, **kwargs)
+    except (ImportError, RuntimeError, ConnectionError) as e:
+        logger.debug("DB cache lookup failed: %s", e)
+        return None
+
+
+def _set_db_cached_data(
+    method: str,
+    data: Any,
+    ticker: str | None = None,
+    **kwargs: Any,
+) -> None:
+    if not _is_db_cache_enabled():
+        return
+
+    try:
+        from tradingagents.database import (
+            MarketDataService,
+            get_db_session,
+            get_default_ttl,
+        )
+
+        ttl_hours = get_default_ttl(method)
+        with get_db_session() as session:
+            service = MarketDataService(session)
+            service.set_cached_data(method, data, ticker, ttl_hours=ttl_hours, **kwargs)
+    except (ImportError, RuntimeError, ConnectionError) as e:
+        logger.debug("DB cache write failed: %s", e)
+
+
 def parse_lookback_period(lookback: str) -> int:
     lookback = lookback.lower().strip()
@@ -277,6 +323,14 @@ def get_vendor(category: str, method: str = None) -> str:

 def route_to_vendor(method: str, *args, **kwargs):
+    ticker = kwargs.get("ticker") or (args[0] if args else None)
+    date = kwargs.get("date") or (args[1] if len(args) > 1 else None)
+
+    cached_result = _get_db_cached_data(method, ticker, date=date)
+    if cached_result is not None:
+        logger.info("DB cache hit for %s (ticker=%s)", method, ticker)
+        return cached_result
+
     category = get_category_for_method(method)
     vendor_config = get_vendor(category, method)
@@ -407,6 +461,10 @@
         )

     if len(results) == 1:
-        return results[0]
+        final_result = results[0]
     else:
-        return "\n".join(str(result) for result in results)
+        final_result = "\n".join(str(result) for result in results)
+
+    _set_db_cached_data(method, final_result, ticker, date=date)
+    return final_result
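
End to end, route_to_vendor now short-circuits on a warm cache and writes back on a miss. An illustrative call sequence (the interface module path and the method arguments are assumptions):

# Sketch: cache behavior through route_to_vendor.
from tradingagents.dataflows.interface import route_to_vendor

# Cold call: _get_db_cached_data misses, vendor routing runs, and the
# final result is cached with get_default_ttl("get_stock_data") == 1 hour.
report = route_to_vendor("get_stock_data", "AAPL", "2025-12-03")

# Warm call within the TTL: returned before any vendor is contacted.
# If the database is unreachable, both helpers log at debug level and the
# call degrades to a normal vendor fetch.
report = route_to_vendor("get_stock_data", "AAPL", "2025-12-03")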