feat: add database-backed caching to dataflows interface
- Add MarketDataService for database-backed market data caching - Integrate cache lookup/write into route_to_vendor function - Support configurable TTL per data type (stock data: 1h, fundamentals: 24h+) - Make caching opt-in via database_enabled and db_cache_enabled config flags 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
1db81e1fc6
commit
fb1a66f5a6
|
|
@ -1,6 +1,12 @@
|
||||||
from .base import Base
|
from .base import Base
|
||||||
from .engine import get_db_session, get_engine, init_database, reset_engine
|
from .engine import get_db_session, get_engine, init_database, reset_engine
|
||||||
from .services import AnalysisService, DiscoveryService, TradingService
|
from .services import (
|
||||||
|
AnalysisService,
|
||||||
|
DiscoveryService,
|
||||||
|
MarketDataService,
|
||||||
|
TradingService,
|
||||||
|
get_default_ttl,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Base",
|
"Base",
|
||||||
|
|
@ -10,5 +16,7 @@ __all__ = [
|
||||||
"reset_engine",
|
"reset_engine",
|
||||||
"AnalysisService",
|
"AnalysisService",
|
||||||
"DiscoveryService",
|
"DiscoveryService",
|
||||||
|
"MarketDataService",
|
||||||
"TradingService",
|
"TradingService",
|
||||||
|
"get_default_ttl",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
from .analysis import AnalysisService
|
from .analysis import AnalysisService
|
||||||
from .discovery import DiscoveryService
|
from .discovery import DiscoveryService
|
||||||
|
from .market_data import MarketDataService, get_default_ttl
|
||||||
from .trading import TradingService
|
from .trading import TradingService
|
||||||
|
|
||||||
# Public API of the services subpackage.
__all__ = [
    "AnalysisService",
    "DiscoveryService",
    "MarketDataService",
    "TradingService",
    "get_default_ttl",
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,139 @@
|
||||||
|
import hashlib
import json
import logging
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from typing import Any

from sqlalchemy.orm import Session

from tradingagents.database.models import DataCache
from tradingagents.database.repositories import DataCacheRepository
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MarketDataService:
    """Database-backed cache for market data, keyed by method/ticker/date.

    Wraps a ``DataCacheRepository`` to read and write string payloads
    (JSON-encoded when the value is not already a string) with a
    per-entry TTL expressed in hours.
    """

    def __init__(self, session: Session):
        self.session = session
        self.cache_repo = DataCacheRepository(session)

    def _generate_cache_key(
        self,
        method: str,
        ticker: str | None = None,
        date: str | None = None,
        **kwargs: Any,
    ) -> str:
        """Build a deterministic 32-hex-char key from the call signature.

        kwargs are sorted so keyword-argument order never changes the key.
        """
        key_parts = [method]
        if ticker:
            key_parts.append(ticker)
        if date:
            key_parts.append(date)
        for k, v in sorted(kwargs.items()):
            key_parts.append(f"{k}={v}")
        key_string = ":".join(key_parts)
        # 32 hex chars (128 bits) keeps keys short while avoiding collisions.
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]

    def get_cached_data(
        self,
        method: str,
        ticker: str | None = None,
        date: str | None = None,
        **kwargs: Any,
    ) -> Any | None:
        """Return the cached payload for this call, or ``None`` on a miss.

        JSON payloads are decoded; payloads that fail JSON decoding are
        returned as the raw stored string.
        """
        cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
        cache_entry = self.cache_repo.get_valid_cache(cache_key)

        if cache_entry and cache_entry.cached_data:
            logger.debug("Cache hit for %s (key: %s)", method, cache_key[:8])
            try:
                return json.loads(cache_entry.cached_data)
            except json.JSONDecodeError:
                # Stored value was a plain string, not JSON.
                return cache_entry.cached_data

        logger.debug("Cache miss for %s (key: %s)", method, cache_key[:8])
        return None

    def set_cached_data(
        self,
        method: str,
        data: Any,
        ticker: str | None = None,
        date: str | None = None,
        ttl_hours: int = 24,
        **kwargs: Any,
    ) -> DataCache:
        """Store *data* under the derived key, expiring after *ttl_hours*.

        Returns the persisted ``DataCache`` row.
        """
        cache_key = self._generate_cache_key(method, ticker, date, **kwargs)
        # Naive UTC timestamp, matching what utcnow() produced;
        # datetime.utcnow() itself is deprecated since Python 3.12.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        expires_at = now_utc + timedelta(hours=ttl_hours)

        # Strings are stored verbatim; everything else is JSON-encoded.
        cached_data = data if isinstance(data, str) else json.dumps(data)

        logger.debug(
            "Caching data for %s (key: %s, ttl: %dh)",
            method,
            cache_key[:8],
            ttl_hours,
        )

        return self.cache_repo.set_cache(
            cache_key=cache_key,
            data_type=method,
            cached_data=cached_data,
            expires_at=expires_at,
            ticker=ticker,
        )

    def get_or_fetch(
        self,
        method: str,
        fetch_func: Callable[[], Any],
        ticker: str | None = None,
        date: str | None = None,
        ttl_hours: int = 24,
        **kwargs: Any,
    ) -> Any:
        """Return cached data, or call *fetch_func* and cache a truthy result."""
        cached = self.get_cached_data(method, ticker, date, **kwargs)
        if cached is not None:
            return cached

        logger.debug("Fetching fresh data for %s", method)
        result = fetch_func()

        # Only truthy results are cached, so empty responses get retried.
        if result:
            self.set_cached_data(method, result, ticker, date, ttl_hours, **kwargs)

        return result

    def clear_expired_cache(self) -> int:
        """Delete expired cache rows; return how many were removed."""
        count = self.cache_repo.clear_expired()
        logger.info("Cleared %d expired cache entries", count)
        return count

    def invalidate_ticker_cache(self, ticker: str) -> int:
        """Delete every cache row for *ticker*; return the number removed.

        Deletes through the ORM (row by row) so any mapper-level cascade
        behavior still applies; changes are flushed but not committed.
        """
        entries = self.session.query(DataCache).filter(DataCache.ticker == ticker).all()
        count = 0
        for entry in entries:
            self.session.delete(entry)
            count += 1
        self.session.flush()
        logger.info("Invalidated %d cache entries for ticker %s", count, ticker)
        return count
|
||||||
|
|
||||||
|
|
||||||
|
# Default cache lifetime (in hours) per data-fetching method.
DEFAULT_TTL_HOURS = {
    # Fast-moving data: refresh hourly.
    "get_stock_data": 1,
    "get_indicators": 1,
    "get_news": 1,
    "get_global_news": 1,
    "get_bulk_news": 1,
    # Fundamentals and insider data: refresh daily.
    "get_fundamentals": 24,
    "get_insider_sentiment": 24,
    "get_insider_transactions": 24,
    # Financial statements: refresh weekly (168h).
    "get_balance_sheet": 168,
    "get_cashflow": 168,
    "get_income_statement": 168,
}


def get_default_ttl(method: str) -> int:
    """Return the default cache TTL in hours for *method* (24 if unknown)."""
    try:
        return DEFAULT_TTL_HOURS[method]
    except KeyError:
        return 24
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from tradingagents.agents.discovery import NewsArticle
|
from tradingagents.agents.discovery import NewsArticle
|
||||||
|
|
||||||
|
|
@ -139,6 +139,52 @@ _bulk_news_cache: dict[str, dict[str, Any]] = {}
|
||||||
_bulk_news_cache_lock = threading.Lock()
|
_bulk_news_cache_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_db_cache_enabled() -> bool:
    """Report whether DB-backed caching is active.

    Requires ``database_enabled`` to be set (defaults to off) and
    ``db_cache_enabled`` not to be explicitly disabled (defaults to on).
    """
    cfg = get_config()
    db_enabled = cfg.get("database_enabled", False)
    if not db_enabled:
        # Preserve short-circuit semantics: return the falsy value itself.
        return db_enabled
    return cfg.get("db_cache_enabled", True)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_db_cached_data(method: str, ticker: str | None = None, **kwargs: Any) -> Any:
    """Look up a cached vendor result in the database.

    Returns the cached payload, or ``None`` on a miss, when caching is
    disabled, or when the database layer is unavailable (best-effort).
    """
    if not _is_db_cache_enabled():
        return None

    try:
        # Imported lazily so the dataflows layer works without the DB extra.
        from tradingagents.database import MarketDataService, get_db_session

        with get_db_session() as session:
            return MarketDataService(session).get_cached_data(method, ticker, **kwargs)
    except (ImportError, RuntimeError, ConnectionError) as exc:
        logger.debug("DB cache lookup failed: %s", exc)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def _set_db_cached_data(
    method: str,
    data: Any,
    ticker: str | None = None,
    **kwargs: Any,
) -> None:
    """Persist a vendor result to the database cache (best-effort).

    No-op when caching is disabled; database-layer failures are logged
    at debug level and swallowed so vendor calls never break on caching.
    """
    if not _is_db_cache_enabled():
        return

    try:
        # Imported lazily so the dataflows layer works without the DB extra.
        from tradingagents.database import (
            MarketDataService,
            get_db_session,
            get_default_ttl,
        )

        with get_db_session() as session:
            MarketDataService(session).set_cached_data(
                method,
                data,
                ticker,
                ttl_hours=get_default_ttl(method),
                **kwargs,
            )
    except (ImportError, RuntimeError, ConnectionError) as exc:
        logger.debug("DB cache write failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
def parse_lookback_period(lookback: str) -> int:
|
def parse_lookback_period(lookback: str) -> int:
|
||||||
lookback = lookback.lower().strip()
|
lookback = lookback.lower().strip()
|
||||||
|
|
||||||
|
|
@ -277,6 +323,14 @@ def get_vendor(category: str, method: str = None) -> str:
|
||||||
|
|
||||||
|
|
||||||
def route_to_vendor(method: str, *args, **kwargs):
|
def route_to_vendor(method: str, *args, **kwargs):
|
||||||
|
ticker = kwargs.get("ticker") or (args[0] if args else None)
|
||||||
|
date = kwargs.get("date") or (args[1] if len(args) > 1 else None)
|
||||||
|
|
||||||
|
cached_result = _get_db_cached_data(method, ticker, date=date)
|
||||||
|
if cached_result is not None:
|
||||||
|
logger.info("DB cache hit for %s (ticker=%s)", method, ticker)
|
||||||
|
return cached_result
|
||||||
|
|
||||||
category = get_category_for_method(method)
|
category = get_category_for_method(method)
|
||||||
vendor_config = get_vendor(category, method)
|
vendor_config = get_vendor(category, method)
|
||||||
|
|
||||||
|
|
@ -407,6 +461,10 @@ def route_to_vendor(method: str, *args, **kwargs):
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(results) == 1:
|
if len(results) == 1:
|
||||||
return results[0]
|
final_result = results[0]
|
||||||
else:
|
else:
|
||||||
return "\n".join(str(result) for result in results)
|
final_result = "\n".join(str(result) for result in results)
|
||||||
|
|
||||||
|
_set_db_cached_data(method, final_result, ticker, date=date)
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue