# TradingAgents/tradingagents/services/market_data_service.py
#
# 347 lines
# 13 KiB
# Python

"""
Market data service that provides structured market context.
"""
import logging
from typing import Any
from tradingagents.clients.base import BaseClient
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
from tradingagents.models.context import (
MarketDataContext,
TechnicalIndicatorData,
)
from tradingagents.repositories.base import BaseRepository
from .base import BaseService
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)
class MarketDataService(BaseService):
    """Service for market data and technical indicators.

    Price data is resolved either "local first" (repository, then API client)
    or by a forced refresh from the API client. Technical indicators are
    computed per day through ``StockstatsUtils.get_stock_stats``.

    NOTE(review): ``self.config``, ``self.online_mode``,
    ``_determine_data_quality`` and ``_create_base_metadata`` are used below
    but not defined in this class; they are presumably provided by
    ``BaseService`` -- confirm against that class.
    """

    def __init__(
        self,
        client: BaseClient | None = None,
        repository: BaseRepository | None = None,
        online_mode: bool = True,
        **kwargs,
    ):
        """
        Initialize market data service.

        Args:
            client: Client for live market data
            repository: Repository for historical market data
            online_mode: Whether to use live data
            **kwargs: Additional configuration, forwarded to ``BaseService``
        """
        super().__init__(online_mode, **kwargs)
        self.client = client
        self.repository = repository
        self.stockstats_utils = StockstatsUtils()

    def get_context(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        indicators: list[str] | None = None,
        force_refresh: bool = False,
        **kwargs,
    ) -> MarketDataContext:
        """
        Get market data context with price data and technical indicators.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            indicators: List of technical indicators to calculate. ``None``
                selects the defaults (``rsi``, ``macd``, ``close_50_sma``);
                an empty list disables indicator calculation.
            force_refresh: If True, skip local data and fetch fresh from API
            **kwargs: Additional parameters (accepted for interface
                compatibility; not used by this method)

        Returns:
            MarketDataContext: Structured market data context
        """
        if indicators is None:
            indicators = ["rsi", "macd", "close_50_sma"]

        # Local-first data strategy with force refresh option
        if force_refresh:
            price_data = self._fetch_and_cache_fresh_data(symbol, start_date, end_date)
            fallback_source = "live_api_refresh"
        else:
            price_data = self._get_price_data_local_first(symbol, start_date, end_date)
            fallback_source = "unknown"

        # Always report the source the fetch step actually recorded. A failed
        # force refresh tags its payload "no_client" / "refresh_error", and
        # that must not be presented downstream as a successful live refresh.
        price_metadata = price_data.get("metadata", {})
        data_source = price_metadata.get("source", fallback_source)

        technical_indicators = self._calculate_indicators(
            symbol, start_date, end_date, indicators
        )

        data_quality = self._determine_data_quality(
            data_source=data_source,
            record_count=len(price_data.get("data", [])),
            has_errors="error" in price_metadata,
        )

        metadata = self._create_base_metadata(
            data_quality=data_quality,
            price_data_source=data_source,
            indicator_count=len(technical_indicators),
            symbol=symbol,
            force_refresh=force_refresh,
        )

        return MarketDataContext(
            symbol=symbol,
            period={"start": start_date, "end": end_date},
            price_data=price_data.get("data", []),
            technical_indicators=technical_indicators,
            metadata=metadata,
        )

    def get_price_context(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> MarketDataContext:
        """
        Get market data context with just price data (no indicators).

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            **kwargs: Additional parameters, forwarded to ``get_context``

        Returns:
            MarketDataContext: Market context with price data only
        """
        return self.get_context(symbol, start_date, end_date, indicators=[], **kwargs)

    @staticmethod
    def _empty_price_data(symbol: str, source: str, error: str) -> dict[str, Any]:
        """Build an empty price payload whose metadata records a failure."""
        return {
            "symbol": symbol,
            "data": [],
            "metadata": {"source": source, "error": error},
        }

    @staticmethod
    def _tag_source(payload: dict[str, Any], source: str) -> dict[str, Any]:
        """Record ``source`` in ``payload["metadata"]`` (created if absent) and return the payload."""
        metadata = payload.get("metadata")
        if metadata is None:
            metadata = {}
        payload["metadata"] = metadata
        metadata["source"] = source
        return payload

    def _get_price_data_local_first(
        self, symbol: str, start_date: str, end_date: str
    ) -> dict[str, Any]:
        """
        Get price data using a local-first strategy.

        Resolution order:
            1. Repository, when it reports full coverage of the period
               (source ``local_cache``).
            2. API client; the result is cached in the repository when one is
               configured (source ``live_api``).
            3. Repository without a coverage guarantee, when no client is
               configured (source ``local_partial``).
            4. Empty payload (source ``none``).

        Any exception, or a ``None`` result from the client, yields an empty
        payload with source ``error``; this method does not raise.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format

        Returns:
            dict: Price payload; ``metadata["source"]`` is always set.
        """
        try:
            # Check if we have sufficient local data
            if self.repository and self.repository.has_data_for_period(
                symbol, start_date, end_date
            ):
                logger.info(
                    f"Using local data for {symbol} ({start_date} to {end_date})"
                )
                local_data = self.repository.get_data(
                    symbol=symbol, start_date=start_date, end_date=end_date
                )
                return self._tag_source(local_data, "local_cache")

            # We don't have sufficient local data - need to fetch from API
            if self.client:
                logger.info(
                    f"Local data insufficient, fetching from API for {symbol} ({start_date} to {end_date})"
                )
                fresh_data = self.client.get_data(
                    symbol=symbol, start_date=start_date, end_date=end_date
                )
                if fresh_data is None:
                    # Report the real problem instead of the TypeError that
                    # subscripting None would otherwise produce.
                    message = "API client returned no data"
                    logger.error(f"Error fetching price data for {symbol}: {message}")
                    return self._empty_price_data(symbol, "error", message)

                # Caching is best-effort: a storage failure must not hide
                # data that was fetched successfully.
                if fresh_data and self.repository:
                    try:
                        self.repository.store_data(symbol, fresh_data)
                        logger.debug(f"Cached fresh data for {symbol}")
                    except Exception as e:
                        logger.warning(f"Failed to cache data for {symbol}: {e}")
                return self._tag_source(fresh_data, "live_api")

            # No client available, try repository as fallback
            if self.repository:
                logger.warning(
                    f"No API client available, using partial local data for {symbol}"
                )
                local_data = self.repository.get_data(
                    symbol=symbol, start_date=start_date, end_date=end_date
                )
                return self._tag_source(local_data, "local_partial")

            logger.warning(f"No data source available for {symbol}")
            return self._empty_price_data(
                symbol, "none", "No client or repository configured"
            )
        except Exception as e:
            logger.error(f"Error fetching price data for {symbol}: {e}")
            return self._empty_price_data(symbol, "error", str(e))

    def _fetch_and_cache_fresh_data(
        self, symbol: str, start_date: str, end_date: str
    ) -> dict[str, Any]:
        """
        Force fetch fresh data from API and cache it, bypassing local data.

        The repository is cleared and rewritten only after the API has
        returned non-empty data, so a failed or empty refresh never destroys
        the existing local copy. Clearing and storing are both best-effort:
        failures are logged and the fresh data is still returned.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format

        Returns:
            dict: Price payload with ``metadata["source"]`` set to
            ``live_api_refresh`` on success, ``no_client`` when no client is
            configured, or ``refresh_error`` on failure. Does not raise.
        """
        try:
            if not self.client:
                logger.warning(f"No API client available for force refresh of {symbol}")
                return self._empty_price_data(
                    symbol, "no_client", "No API client configured for force refresh"
                )

            logger.info(
                f"Force refreshing data from API for {symbol} ({start_date} to {end_date})"
            )

            # Fetch first: the old data is only replaced once new data exists.
            fresh_data = self.client.get_data(
                symbol=symbol, start_date=start_date, end_date=end_date
            )
            if fresh_data is None:
                message = "API client returned no data"
                logger.error(f"Error force refreshing data for {symbol}: {message}")
                return self._empty_price_data(symbol, "refresh_error", message)

            if fresh_data and self.repository:
                # Drop stale records for the period, then store the fresh ones.
                try:
                    self.repository.clear_data(symbol, start_date, end_date)
                    logger.debug(f"Cleared existing data for {symbol}")
                except Exception as e:
                    logger.warning(f"Failed to clear existing data for {symbol}: {e}")
                try:
                    self.repository.store_data(symbol, fresh_data, overwrite=True)
                    logger.debug(f"Cached refreshed data for {symbol}")
                except Exception as e:
                    logger.warning(f"Failed to cache refreshed data for {symbol}: {e}")

            return self._tag_source(fresh_data, "live_api_refresh")
        except Exception as e:
            logger.error(f"Error force refreshing data for {symbol}: {e}")
            return self._empty_price_data(symbol, "refresh_error", str(e))

    def _calculate_indicators(
        self, symbol: str, start_date: str, end_date: str, indicators: list[str]
    ) -> dict[str, list[TechnicalIndicatorData]]:
        """
        Calculate technical indicators.

        Each indicator is computed independently; one that fails or yields no
        data points is logged and left out of the result.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            indicators: Indicator names to calculate

        Returns:
            dict: Indicator name -> data points, only for indicators that
            produced at least one point.
        """
        if not indicators:
            return {}

        technical_data = {}
        for indicator in indicators:
            try:
                logger.info(f"Calculating {indicator} for {symbol}")
                indicator_data = self._get_indicator_data(
                    symbol, indicator, start_date, end_date
                )
                if indicator_data:
                    technical_data[indicator] = indicator_data
                else:
                    logger.warning(f"No data returned for indicator {indicator}")
            except Exception as e:
                logger.error(f"Error calculating {indicator} for {symbol}: {e}")
        return technical_data

    @staticmethod
    def _coerce_indicator_value(raw_value: Any) -> Any:
        """
        Normalise one raw indicator value.

        Numbers and numeric strings become ``float``. A non-numeric string is
        logged and mapped to ``None`` (meaning "skip this point"). Any other
        type is passed through unchanged.
        """
        if isinstance(raw_value, (int, float)):
            return float(raw_value)
        if isinstance(raw_value, str):
            try:
                return float(raw_value)
            except ValueError:
                logger.warning(f"Could not parse indicator value: {raw_value}")
                return None
        # For complex indicators like MACD, this might be a dict
        return raw_value

    def _get_indicator_data(
        self, symbol: str, indicator: str, start_date: str, end_date: str
    ) -> list[TechnicalIndicatorData]:
        """
        Get indicator data using StockstatsUtils.

        ``StockstatsUtils.get_stock_stats`` is called once per calendar day,
        from ``end_date`` back to ``start_date``. Days with no value, an
        unparsable value, or a lookup error are skipped.

        Args:
            symbol: Stock ticker symbol
            indicator: Indicator name passed to ``get_stock_stats``
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format

        Returns:
            list: Data points in chronological order. Empty when the range is
            inverted, the dates cannot be parsed, or no day has a value.
        """
        try:
            from datetime import datetime, timedelta

            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
            start_dt = datetime.strptime(start_date, "%Y-%m-%d")

            # Loop-invariant: resolve the price data directory once.
            # This assumes the existing data directory structure.
            data_dir = self.config.get("data_dir", "data")
            price_data_dir = f"{data_dir}/market_data/price_data"

            indicator_points = []
            span_days = (end_dt - start_dt).days  # negative -> empty range below
            for offset in range(span_days + 1):
                date_str = (end_dt - timedelta(days=offset)).strftime("%Y-%m-%d")
                try:
                    raw_value = StockstatsUtils.get_stock_stats(
                        symbol,
                        indicator,
                        date_str,
                        price_data_dir,
                        online=self.online_mode,
                    )
                    has_value = raw_value is not None and raw_value != ""
                    if not has_value:
                        continue

                    value = self._coerce_indicator_value(raw_value)
                    if value is None:
                        continue

                    indicator_points.append(
                        TechnicalIndicatorData(
                            date=date_str, value=value, indicator_type=indicator
                        )
                    )
                except Exception as e:
                    logger.debug(
                        f"Could not get {indicator} for {symbol} on {date_str}: {e}"
                    )

            # Points were collected newest-first; return in chronological order.
            indicator_points.reverse()
            return indicator_points
        except Exception as e:
            logger.error(f"Error getting indicator data for {indicator}: {e}")
            return []