feat: add database-backed caching to dataflows interface

- Add MarketDataService for database-backed market data caching
- Integrate cache lookup/write into route_to_vendor function
- Support configurable TTL per data type (stock data: 1h, fundamentals: 24h+)
- Make caching opt-in via database_enabled and db_cache_enabled config flags

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 11:45:46 -05:00
parent 1db81e1fc6
commit fb1a66f5a6
4 changed files with 212 additions and 4 deletions

View File

@ -1,6 +1,12 @@
from .base import Base
from .engine import get_db_session, get_engine, init_database, reset_engine
from .services import AnalysisService, DiscoveryService, TradingService
from .services import (
AnalysisService,
DiscoveryService,
MarketDataService,
TradingService,
get_default_ttl,
)
__all__ = [
"Base",
@ -10,5 +16,7 @@ __all__ = [
"reset_engine",
"AnalysisService",
"DiscoveryService",
"MarketDataService",
"TradingService",
"get_default_ttl",
]

View File

@ -1,9 +1,12 @@
from .analysis import AnalysisService
from .discovery import DiscoveryService
from .market_data import MarketDataService, get_default_ttl
from .trading import TradingService
__all__ = [
"AnalysisService",
"DiscoveryService",
"MarketDataService",
"TradingService",
"get_default_ttl",
]

View File

@ -0,0 +1,139 @@
from __future__ import annotations

import hashlib
import json
import logging
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any

from sqlalchemy.orm import Session

from tradingagents.database.models import DataCache
from tradingagents.database.repositories import DataCacheRepository
logger = logging.getLogger(__name__)
class MarketDataService:
    """Database-backed cache for market data fetched from external vendors.

    Payloads are stored via ``DataCacheRepository`` as JSON (or raw strings),
    keyed by a hash of the call signature (method, ticker, date, kwargs),
    with each entry carrying an expiry timestamp.
    """

    def __init__(self, session: Session):
        """Bind the service to an open SQLAlchemy *session*."""
        self.session = session
        self.cache_repo = DataCacheRepository(session)

    def _generate_cache_key(
        self,
        method: str,
        ticker: str | None = None,
        date: str | None = None,
        **kwargs: Any,
    ) -> str:
        """Return a deterministic 32-hex-char key for this call signature.

        kwargs are sorted so keyword order never changes the key; empty
        ticker/date are simply omitted from the key material.
        """
        key_parts = [method]
        if ticker:
            key_parts.append(ticker)
        if date:
            key_parts.append(date)
        for k, v in sorted(kwargs.items()):
            key_parts.append(f"{k}={v}")
        key_string = ":".join(key_parts)
        # Truncated SHA-256: short enough for an indexed column while the
        # collision risk at 128 bits stays negligible.
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]

    def get_cached_data(
        self,
        method: str,
        ticker: str | None = None,
        date: str | None = None,
        **kwargs: Any,
    ) -> Any | None:
        """Return the cached value for this call, or None on a miss.

        JSON payloads are decoded; a payload that fails to decode is returned
        as the raw cached string (data stored verbatim by set_cached_data).
        """
        cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
        cache_entry = self.cache_repo.get_valid_cache(cache_key)
        if cache_entry and cache_entry.cached_data:
            logger.debug("Cache hit for %s (key: %s)", method, cache_key[:8])
            try:
                return json.loads(cache_entry.cached_data)
            except json.JSONDecodeError:
                # Entry was stored as a plain string (see set_cached_data).
                return cache_entry.cached_data
        logger.debug("Cache miss for %s (key: %s)", method, cache_key[:8])
        return None

    def set_cached_data(
        self,
        method: str,
        data: Any,
        ticker: str | None = None,
        date: str | None = None,
        ttl_hours: int = 24,
        **kwargs: Any,
    ) -> DataCache:
        """Store *data* under this call signature with a TTL in hours.

        Strings are stored verbatim; everything else is JSON-serialized.
        Returns the upserted DataCache row.
        """
        cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # switching to timezone-aware datetimes needs the semantics of the
        # DataCache.expires_at column confirmed first.
        expires_at = datetime.utcnow() + timedelta(hours=ttl_hours)
        cached_data = data if isinstance(data, str) else json.dumps(data)
        logger.debug(
            "Caching data for %s (key: %s, ttl: %dh)",
            method,
            cache_key[:8],
            ttl_hours,
        )
        return self.cache_repo.set_cache(
            cache_key=cache_key,
            data_type=method,
            cached_data=cached_data,
            expires_at=expires_at,
            ticker=ticker,
        )

    def get_or_fetch(
        self,
        method: str,
        fetch_func: Callable[[], Any],
        ticker: str | None = None,
        date: str | None = None,
        ttl_hours: int = 24,
        **kwargs: Any,
    ) -> Any:
        """Return cached data, or call *fetch_func* and cache its result.

        Falsy results (None, "", empty containers) are returned but NOT
        cached, so the next call retries the fetch — presumably to avoid
        pinning empty/failed vendor responses; confirm this is intended.
        """
        cached = self.get_cached_data(method, ticker, date, **kwargs)
        if cached is not None:
            return cached
        logger.debug("Fetching fresh data for %s", method)
        result = fetch_func()
        if result:
            self.set_cached_data(method, result, ticker, date, ttl_hours, **kwargs)
        return result

    def clear_expired_cache(self) -> int:
        """Delete expired cache rows; return the number removed."""
        count = self.cache_repo.clear_expired()
        logger.info("Cleared %d expired cache entries", count)
        return count

    def invalidate_ticker_cache(self, ticker: str) -> int:
        """Delete every cache row for *ticker*; return the number removed."""
        entries = self.session.query(DataCache).filter(DataCache.ticker == ticker).all()
        count = len(entries)
        for entry in entries:
            self.session.delete(entry)
        # Flush (not commit) so the deletions are visible within the caller's
        # transaction; committing stays the session owner's responsibility.
        self.session.flush()
        logger.info("Invalidated %d cache entries for ticker %s", count, ticker)
        return count
# Fallback TTL for methods without an explicit entry below.
_FALLBACK_TTL_HOURS = 24

# Default cache lifetimes in hours, per dataflow method: fast-moving data
# (prices, indicators, news) expires hourly, fundamentals daily, and
# financial statements weekly.
DEFAULT_TTL_HOURS = {
    "get_stock_data": 1,
    "get_indicators": 1,
    "get_fundamentals": 24,
    "get_balance_sheet": 168,
    "get_cashflow": 168,
    "get_income_statement": 168,
    "get_news": 1,
    "get_global_news": 1,
    "get_insider_sentiment": 24,
    "get_insider_transactions": 24,
    "get_bulk_news": 1,
}


def get_default_ttl(method: str) -> int:
    """Return the default cache TTL in hours for *method* (24 if unknown)."""
    return DEFAULT_TTL_HOURS.get(method, _FALLBACK_TTL_HOURS)

View File

@ -1,7 +1,7 @@
import logging
import threading
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any
from tradingagents.agents.discovery import NewsArticle
@ -139,6 +139,52 @@ _bulk_news_cache: dict[str, dict[str, Any]] = {}
_bulk_news_cache_lock = threading.Lock()
def _is_db_cache_enabled() -> bool:
    """Report whether database-backed caching should be used.

    Requires ``database_enabled`` (off by default) AND ``db_cache_enabled``
    (on by default, so enabling the database opts into caching).
    """
    cfg = get_config()
    db_on = cfg.get("database_enabled", False)
    cache_on = cfg.get("db_cache_enabled", True)
    return db_on and cache_on
def _get_db_cached_data(method: str, ticker: str | None = None, **kwargs: Any) -> Any:
    """Best-effort lookup of *method*'s result in the database cache.

    Returns the cached value, or None on a miss or when caching is disabled.
    Any failure (missing database extras, connection/driver errors) is
    logged at debug level and treated as a miss so routing still proceeds.
    """
    if not _is_db_cache_enabled():
        return None
    try:
        from tradingagents.database import MarketDataService, get_db_session

        with get_db_session() as session:
            service = MarketDataService(session)
            return service.get_cached_data(method, ticker, **kwargs)
    except Exception as e:  # noqa: BLE001 - cache is best-effort: SQLAlchemy
        # errors (e.g. OperationalError) do not derive from ConnectionError,
        # so a narrow tuple would let a DB outage crash route_to_vendor.
        logger.debug("DB cache lookup failed: %s", e)
        return None
def _set_db_cached_data(
    method: str,
    data: Any,
    ticker: str | None = None,
    **kwargs: Any,
) -> None:
    """Best-effort write of *data* into the database cache.

    The TTL comes from get_default_ttl(method). Any failure (missing
    database extras, connection/driver errors) is logged at debug level
    and swallowed so a cache outage never breaks the data fetch itself.
    """
    if not _is_db_cache_enabled():
        return
    try:
        from tradingagents.database import (
            MarketDataService,
            get_db_session,
            get_default_ttl,
        )

        ttl_hours = get_default_ttl(method)
        with get_db_session() as session:
            service = MarketDataService(session)
            service.set_cached_data(method, data, ticker, ttl_hours=ttl_hours, **kwargs)
    except Exception as e:  # noqa: BLE001 - cache is best-effort: SQLAlchemy
        # errors (e.g. OperationalError) do not derive from ConnectionError,
        # so a narrow tuple would let a DB outage crash the write path.
        logger.debug("DB cache write failed: %s", e)
def parse_lookback_period(lookback: str) -> int:
lookback = lookback.lower().strip()
@ -277,6 +323,14 @@ def get_vendor(category: str, method: str = None) -> str:
def route_to_vendor(method: str, *args, **kwargs):
ticker = kwargs.get("ticker") or (args[0] if args else None)
date = kwargs.get("date") or (args[1] if len(args) > 1 else None)
cached_result = _get_db_cached_data(method, ticker, date=date)
if cached_result is not None:
logger.info("DB cache hit for %s (ticker=%s)", method, ticker)
return cached_result
category = get_category_for_method(method)
vendor_config = get_vendor(category, method)
@ -407,6 +461,10 @@ def route_to_vendor(method: str, *args, **kwargs):
)
if len(results) == 1:
return results[0]
final_result = results[0]
else:
return "\n".join(str(result) for result in results)
final_result = "\n".join(str(result) for result in results)
_set_db_cached_data(method, final_result, ticker, date=date)
return final_result