From fb1a66f5a67f220d8a735006ffd1a662eaf8eeb9 Mon Sep 17 00:00:00 2001 From: Joseph O'Brien <98370624+89jobrien@users.noreply.github.com> Date: Wed, 3 Dec 2025 11:45:46 -0500 Subject: [PATCH] feat: add database-backed caching to dataflows interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add MarketDataService for database-backed market data caching - Integrate cache lookup/write into route_to_vendor function - Support configurable TTL per data type (stock data: 1h, fundamentals: 24h+) - Make caching opt-in via database_enabled and db_cache_enabled config flags 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tradingagents/database/__init__.py | 10 +- tradingagents/database/services/__init__.py | 3 + .../database/services/market_data.py | 139 ++++++++++++++++++ tradingagents/dataflows/interface.py | 64 +++++++- 4 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 tradingagents/database/services/market_data.py diff --git a/tradingagents/database/__init__.py b/tradingagents/database/__init__.py index 566427af..05ce0a9c 100644 --- a/tradingagents/database/__init__.py +++ b/tradingagents/database/__init__.py @@ -1,6 +1,12 @@ from .base import Base from .engine import get_db_session, get_engine, init_database, reset_engine -from .services import AnalysisService, DiscoveryService, TradingService +from .services import ( + AnalysisService, + DiscoveryService, + MarketDataService, + TradingService, + get_default_ttl, +) __all__ = [ "Base", @@ -10,5 +16,7 @@ __all__ = [ "reset_engine", "AnalysisService", "DiscoveryService", + "MarketDataService", "TradingService", + "get_default_ttl", ] diff --git a/tradingagents/database/services/__init__.py b/tradingagents/database/services/__init__.py index 1571974c..329bbe97 100644 --- a/tradingagents/database/services/__init__.py +++ b/tradingagents/database/services/__init__.py @@ -1,9 +1,12 @@ from .analysis import AnalysisService from .discovery import DiscoveryService +from .market_data import MarketDataService, get_default_ttl from .trading import TradingService __all__ = [ "AnalysisService", "DiscoveryService", + "MarketDataService", "TradingService", + "get_default_ttl", ] diff --git a/tradingagents/database/services/market_data.py b/tradingagents/database/services/market_data.py new file mode 100644 index 00000000..7a837604 --- /dev/null +++ b/tradingagents/database/services/market_data.py @@ -0,0 +1,139 @@ +import hashlib +import json +import logging +from datetime import datetime, timedelta +from typing import Any + +from sqlalchemy.orm import Session + +from tradingagents.database.models import DataCache +from tradingagents.database.repositories import DataCacheRepository + +logger = logging.getLogger(__name__) + + +class MarketDataService: + def __init__(self, session: Session): + self.session = session + self.cache_repo = DataCacheRepository(session) + + def _generate_cache_key( + self, + method: str, + ticker: str | None = None, + date: str | None = None, + **kwargs: Any, + ) -> str: + key_parts = [method] + if ticker: + key_parts.append(ticker) + if date: + key_parts.append(date) + for k, v in sorted(kwargs.items()): + key_parts.append(f"{k}={v}") + key_string = ":".join(key_parts) + return hashlib.sha256(key_string.encode()).hexdigest()[:32] + + def get_cached_data( + self, + method: str, + ticker: str | None = None, + date: str | None = None, + **kwargs: Any, + ) -> Any | None: + cache_key = self._generate_cache_key(method, ticker, 
date, **kwargs) + cache_entry = self.cache_repo.get_valid_cache(cache_key) + + if cache_entry and cache_entry.cached_data: + logger.debug("Cache hit for %s (key: %s)", method, cache_key[:8]) + try: + return json.loads(cache_entry.cached_data) + except json.JSONDecodeError: + return cache_entry.cached_data + + logger.debug("Cache miss for %s (key: %s)", method, cache_key[:8]) + return None + + def set_cached_data( + self, + method: str, + data: Any, + ticker: str | None = None, + date: str | None = None, + ttl_hours: int = 24, + **kwargs: Any, + ) -> DataCache: + cache_key = self._generate_cache_key(method, ticker, date, **kwargs) + expires_at = datetime.utcnow() + timedelta(hours=ttl_hours) + + cached_data = data if isinstance(data, str) else json.dumps(data) + + logger.debug( + "Caching data for %s (key: %s, ttl: %dh)", + method, + cache_key[:8], + ttl_hours, + ) + + return self.cache_repo.set_cache( + cache_key=cache_key, + data_type=method, + cached_data=cached_data, + expires_at=expires_at, + ticker=ticker, + ) + + def get_or_fetch( + self, + method: str, + fetch_func: callable, + ticker: str | None = None, + date: str | None = None, + ttl_hours: int = 24, + **kwargs: Any, + ) -> Any: + cached = self.get_cached_data(method, ticker, date, **kwargs) + if cached is not None: + return cached + + logger.debug("Fetching fresh data for %s", method) + result = fetch_func() + + if result: + self.set_cached_data(method, result, ticker, date, ttl_hours, **kwargs) + + return result + + def clear_expired_cache(self) -> int: + count = self.cache_repo.clear_expired() + logger.info("Cleared %d expired cache entries", count) + return count + + def invalidate_ticker_cache(self, ticker: str) -> int: + entries = self.session.query(DataCache).filter(DataCache.ticker == ticker).all() + count = 0 + for entry in entries: + self.session.delete(entry) + count += 1 + self.session.flush() + logger.info("Invalidated %d cache entries for ticker %s", count, ticker) + return count + + +DEFAULT_TTL_HOURS = { + "get_stock_data": 1, + "get_indicators": 1, + "get_fundamentals": 24, + "get_balance_sheet": 168, + "get_cashflow": 168, + "get_income_statement": 168, + "get_news": 1, + "get_global_news": 1, + "get_insider_sentiment": 24, + "get_insider_transactions": 24, + "get_bulk_news": 1, +} + + +def get_default_ttl(method: str) -> int: + return DEFAULT_TTL_HOURS.get(method, 24) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 0bdce98b..fa072850 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -1,7 +1,7 @@ import logging import threading from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any from tradingagents.agents.discovery import NewsArticle @@ -139,6 +139,52 @@ _bulk_news_cache: dict[str, dict[str, Any]] = {} _bulk_news_cache_lock = threading.Lock() +def _is_db_cache_enabled() -> bool: + config = get_config() + return config.get("database_enabled", False) and config.get( + "db_cache_enabled", True + ) + + +def _get_db_cached_data(method: str, ticker: str | None = None, **kwargs: Any) -> Any: + if not _is_db_cache_enabled(): + return None + + try: + from tradingagents.database import MarketDataService, get_db_session + + with get_db_session() as session: + service = MarketDataService(session) + return service.get_cached_data(method, ticker, **kwargs) + except (ImportError, RuntimeError, ConnectionError) as e: + logger.debug("DB cache lookup failed: %s", e) + return None + + +def 
_set_db_cached_data( + method: str, + data: Any, + ticker: str | None = None, + **kwargs: Any, +) -> None: + if not _is_db_cache_enabled(): + return + + try: + from tradingagents.database import ( + MarketDataService, + get_db_session, + get_default_ttl, + ) + + ttl_hours = get_default_ttl(method) + with get_db_session() as session: + service = MarketDataService(session) + service.set_cached_data(method, data, ticker, ttl_hours=ttl_hours, **kwargs) + except (ImportError, RuntimeError, ConnectionError) as e: + logger.debug("DB cache write failed: %s", e) + + def parse_lookback_period(lookback: str) -> int: lookback = lookback.lower().strip() @@ -277,6 +323,14 @@ def get_vendor(category: str, method: str = None) -> str: def route_to_vendor(method: str, *args, **kwargs): + ticker = kwargs.get("ticker") or (args[0] if args else None) + date = kwargs.get("date") or (args[1] if len(args) > 1 else None) + + cached_result = _get_db_cached_data(method, ticker, date=date) + if cached_result is not None: + logger.info("DB cache hit for %s (ticker=%s)", method, ticker) + return cached_result + category = get_category_for_method(method) vendor_config = get_vendor(category, method) @@ -407,6 +461,10 @@ def route_to_vendor(method: str, *args, **kwargs): ) if len(results) == 1: - return results[0] + final_result = results[0] else: - return "\n".join(str(result) for result in results) + final_result = "\n".join(str(result) for result in results) + + _set_db_cached_data(method, final_result, ticker, date=date) + + return final_result
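
Example usage (a minimal sketch, not part of the patch): assuming database_enabled is turned on in the config (db_cache_enabled defaults to true) and the database has been initialized, route_to_vendor() transparently serves repeat calls from the cache. The new service can also be used directly, as below; fetch_quotes is a hypothetical stand-in for a real vendor call.

    from tradingagents.database import (
        MarketDataService,
        get_db_session,
        get_default_ttl,
    )

    def fetch_quotes() -> dict:
        # Hypothetical fetcher; in practice this would be a vendor API call.
        return {"ticker": "AAPL", "close": 195.12}

    with get_db_session() as session:
        service = MarketDataService(session)
        # Returns the cached payload when a valid DataCache entry exists;
        # otherwise calls fetch_quotes() and stores the result with the
        # default 1-hour TTL for "get_stock_data".
        data = service.get_or_fetch(
            "get_stock_data",
            fetch_quotes,
            ticker="AAPL",
            date="2025-12-03",
            ttl_hours=get_default_ttl("get_stock_data"),
        )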