# TradingAgents/tradingagents/dataflows/stockstats_utils.py
# (file-viewer metadata from the original paste: 87 lines, 3.1 KiB, Python)

import pandas as pd
import yfinance as yf
from stockstats import wrap
from typing import Annotated
import os
from .config import get_config, DATA_DIR
class StockstatsUtils:
    """Compute stockstats technical indicators from daily price data."""

    @staticmethod
    def _load_local_prices(symbol: str) -> pd.DataFrame:
        """Read the pre-downloaded price CSV for *symbol* from ``DATA_DIR``.

        Raises:
            Exception: if the expected CSV file does not exist.
        """
        local_file = os.path.join(
            DATA_DIR,
            f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
        )
        try:
            prices = pd.read_csv(local_file)
        except FileNotFoundError as err:
            # Same exception type and message as before; the cause is now
            # chained so the missing path shows up in the traceback.
            raise Exception(
                "Stockstats fail: Yahoo Finance data not fetched yet!"
            ) from err
        prices["Date"] = pd.to_datetime(prices["Date"])
        return prices

    @staticmethod
    def _load_online_prices(
        symbol: str,
        start_date_dt: pd.Timestamp,
        end_date_dt: pd.Timestamp,
        cache_dir: str,
    ) -> pd.DataFrame:
        """Return daily prices for *symbol*, using a CSV cache in *cache_dir*.

        The cache file name is keyed by the start date and the (inclusive)
        analysis date. On a cache miss the data is downloaded with yfinance
        and written to the cache unless the download came back empty.
        """
        start_date = start_date_dt.strftime("%Y-%m-%d")
        end_date = end_date_dt.strftime("%Y-%m-%d")

        os.makedirs(cache_dir, exist_ok=True)
        data_file = os.path.join(
            cache_dir,
            f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
        )

        if os.path.exists(data_file):
            # NOTE: a cache file written before the exclusive-end fix below
            # will not contain the analysis date's row; delete it to refetch.
            prices = pd.read_csv(data_file)
            prices["Date"] = pd.to_datetime(prices["Date"])
            return prices

        # yfinance treats `end` as EXCLUSIVE. Passing the analysis date itself
        # would drop that day's bar and make every lookup report "N/A", so
        # request one extra day; the caller filters out anything later.
        fetch_end = (end_date_dt + pd.DateOffset(days=1)).strftime("%Y-%m-%d")
        prices = yf.download(
            symbol,
            start=start_date,
            end=fetch_end,
            multi_level_index=False,
            progress=False,
            auto_adjust=True,
        )
        prices = prices.reset_index()
        # Do not persist an empty result (bad ticker, transient failure):
        # it would otherwise be served from the cache on every later call.
        if not prices.empty:
            prices.to_csv(data_file, index=False)
        return prices

    @staticmethod
    def get_stock_stats(
        symbol: Annotated[str, "ticker symbol for the company"],
        indicator: Annotated[
            str, "quantitative indicators based off of the stock data for the company"
        ],
        curr_date: Annotated[
            str, "curr date for retrieving stock price data, YYYY-mm-dd"
        ],
    ):
        """Return the value of *indicator* for *symbol* on *curr_date*.

        Price data comes from the local CSV when the configured
        ``technical_indicators`` vendor is ``"local"``; otherwise it is
        downloaded with yfinance (two years ending on *curr_date*) and cached.
        Rows dated after *curr_date* are always discarded, so a backtest never
        sees data beyond the analysis date.

        Returns:
            The indicator value for the row dated *curr_date*, or the string
            ``"N/A: Not a trading day (weekend or holiday)"`` when there is no
            such row.

        Raises:
            Exception: in local mode, if the price CSV has not been fetched.
        """
        config = get_config()
        online = config["data_vendors"]["technical_indicators"] != "local"

        # The analysis date is the last date whose data may be used.
        curr_date_dt = pd.to_datetime(curr_date)

        if online:
            # Two-year window (reduced from 15 years for faster fetching).
            start_date_dt = curr_date_dt - pd.DateOffset(years=2)
            data = StockstatsUtils._load_online_prices(
                symbol, start_date_dt, curr_date_dt, config["data_cache_dir"]
            )
        else:
            data = StockstatsUtils._load_local_prices(symbol)

        # CRITICAL: drop anything after the analysis date, whatever the source.
        data = data[data["Date"] <= curr_date_dt]

        df = wrap(data)
        # Normalise dates to plain YYYY-mm-dd strings for both sources so the
        # lookup below compares like with like.
        df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
        target_day = curr_date_dt.strftime("%Y-%m-%d")

        df[indicator]  # trigger stockstats to calculate the indicator
        matching_rows = df[df["Date"] == target_day]

        if matching_rows.empty:
            return "N/A: Not a trading day (weekend or holiday)"
        return matching_rows[indicator].values[0]