TradingAgents/tradingagents/dataflows/stockstats_utils.py

289 lines
9.9 KiB
Python

import time
import logging
import threading
import pandas as pd
import yfinance as yf
import requests
from yfinance.exceptions import YFRateLimitError
from stockstats import wrap
from typing import Annotated
import os
from .config import get_config
# Module logger; used below for retry and data-source fallback warnings.
logger = logging.getLogger(__name__)
# Thread-local holder for the fallback HTTP session (see _get_fallback_session),
# so each thread lazily gets its own requests.Session.
_fallback_session_local = threading.local()
def _get_fallback_session() -> requests.Session:
    """Return the calling thread's ``requests.Session`` for fallback data sources.

    The session is created on first use in each thread and cached in
    ``_fallback_session_local``. ``trust_env`` is disabled on it, so
    requests does not pick up environment-derived settings (e.g. proxies).
    """
    cached = getattr(_fallback_session_local, "session", None)
    if cached is not None:
        return cached
    fresh = requests.Session()
    fresh.trust_env = False
    _fallback_session_local.session = fresh
    return fresh
def _symbol_to_tencent_code(symbol: str) -> str:
code, exchange = symbol.upper().split(".")
if exchange == "SS":
return f"sh{code}"
if exchange == "SZ":
return f"sz{code}"
raise ValueError(f"Unsupported A-share symbol for Tencent fallback: {symbol}")
def _fetch_tencent_ohlcv(
    symbol: str, start_date: str, end_date: str, limit: int = 320
) -> pd.DataFrame:
    """Fallback daily OHLCV fetch for A-shares via Tencent.

    Args:
        symbol: Yahoo-style A-share ticker (``<code>.SS`` / ``<code>.SZ``).
        start_date: First date, ``YYYY-mm-dd``.
        end_date: Last date, ``YYYY-mm-dd``.
        limit: Maximum bar count sent to Tencent in the request.
            NOTE(review): the default (320) is far fewer bars than the
            multi-year window ``load_ohlcv`` requests. Confirm whether Tencent
            keeps the earliest or the latest bars when the range is truncated.

    Returns:
        DataFrame with columns Date, Open, High, Low, Close, Volume.

    Raises:
        ValueError: for an unsupported symbol, or when no rows come back.
        requests.HTTPError: on a non-2xx response.
    """
    # Resolve the code once, before any network setup, so an unsupported
    # symbol fails fast.
    tencent_code = _symbol_to_tencent_code(symbol)
    session = _get_fallback_session()
    response = session.get(
        "https://web.ifzq.gtimg.cn/appstock/app/fqkline/get",
        params={
            "param": f"{tencent_code},day,{start_date},{end_date},{int(limit)},qfq"
        },
        headers={
            "User-Agent": "Mozilla/5.0",
            "Referer": "https://gu.qq.com/",
        },
        timeout=20,
    )
    response.raise_for_status()
    payload = response.json()
    data = ((payload or {}).get("data") or {}).get(tencent_code) or {}
    # Rows are read from "qfqday" when present, otherwise from "day".
    rows = data.get("qfqday") or data.get("day") or []
    if not rows:
        raise ValueError(f"No Tencent OHLCV data returned for {symbol}")

    parsed = []
    for line in rows:
        # [date, open, close, high, low, volume]; any trailing entries are ignored.
        date_str, open_p, close_p, high_p, low_p, volume = line[:6]
        parsed.append(
            {
                "Date": date_str,
                "Open": float(open_p),
                "High": float(high_p),
                "Low": float(low_p),
                "Close": float(close_p),
                "Volume": float(volume),
            }
        )
    return pd.DataFrame(parsed)
def _symbol_to_eastmoney_secid(symbol: str) -> str:
code, exchange = symbol.upper().split(".")
if exchange == "SS":
return f"1.{code}"
if exchange in {"SZ", "BJ"}:
return f"0.{code}"
raise ValueError(f"Unsupported A-share symbol for Eastmoney fallback: {symbol}")
def _fetch_eastmoney_ohlcv(symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
    """Download daily OHLCV bars for an A-share from Eastmoney (fallback source).

    Args:
        symbol: Yahoo-style A-share ticker (``.SS``, ``.SZ`` or ``.BJ`` suffix).
        start_date: First date, ``YYYY-mm-dd`` (dashes are stripped for the API).
        end_date: Last date, ``YYYY-mm-dd`` (dashes are stripped for the API).

    Returns:
        DataFrame with columns Date, Open, High, Low, Close, Volume, Amount.
        NOTE(review): confirm the Volume units match the primary yfinance source.

    Raises:
        ValueError: for an unsupported symbol, or when no klines come back.
        requests.HTTPError: on a non-2xx response.
    """
    session = _get_fallback_session()
    query = {
        "secid": _symbol_to_eastmoney_secid(symbol),
        "fields1": "f1,f2,f3,f4,f5,f6",
        "fields2": "f51,f52,f53,f54,f55,f56,f57,f58,f59,f60,f61",
        "klt": "101",
        "fqt": "1",
        "beg": start_date.replace("-", ""),
        "end": end_date.replace("-", ""),
        "ut": "fa5fd1943c7b386f172d6893dbfba10b",
    }
    response = session.get(
        "https://push2his.eastmoney.com/api/qt/stock/kline/get",
        params=query,
        headers={
            "User-Agent": "Mozilla/5.0",
            "Referer": "https://quote.eastmoney.com/",
        },
        timeout=20,
    )
    response.raise_for_status()
    body = response.json()
    klines = ((body or {}).get("data") or {}).get("klines") or []
    if not klines:
        raise ValueError(f"No Eastmoney OHLCV data returned for {symbol}")

    def _to_record(kline: str) -> dict:
        # Comma-separated: date, open, close, high, low, volume, amount, then extras.
        stamp, open_s, close_s, high_s, low_s, vol_s, amount_s, *_extra = kline.split(",")
        return {
            "Date": stamp,
            "Open": float(open_s),
            "High": float(high_s),
            "Low": float(low_s),
            "Close": float(close_s),
            "Volume": float(vol_s),
            "Amount": float(amount_s),
        }

    return pd.DataFrame([_to_record(kline) for kline in klines])
def _is_transient_yfinance_error(exc: Exception) -> bool:
    """Decide whether *exc* looks like a retryable, flaky yfinance failure.

    Two cases count as transient: yfinance's ``YFRateLimitError``, and a
    ``TypeError`` whose message reports that a ``NoneType`` object was
    subscripted (a parser failure seen on flaky responses).
    """
    if isinstance(exc, YFRateLimitError):
        return True
    description = str(exc)
    if not isinstance(exc, TypeError):
        return False
    return "'NoneType' object is not subscriptable" in description
def yf_retry(func, max_retries=3, base_delay=2.0):
    """Execute a yfinance call with exponential backoff on transient failures.

    yfinance raises YFRateLimitError on HTTP 429 responses but does not
    retry them internally. This wrapper retries rate limits and the observed
    transient parser failures recognised by ``_is_transient_yfinance_error``.
    Other exceptions propagate immediately.

    Args:
        func: Zero-argument callable performing the yfinance request.
        max_retries: Retries allowed after the first attempt. Negative values
            are treated as 0, so ``func`` always runs at least once; previously
            a negative value skipped the call and silently returned None.
        base_delay: Seconds to sleep before the first retry; doubles each time.

    Returns:
        Whatever ``func`` returns.

    Raises:
        The first non-transient exception from ``func``, or the last transient
        one once retries are exhausted.
    """
    retries = max(max_retries, 0)
    for attempt in range(retries + 1):
        try:
            return func()
        except Exception as exc:
            if not _is_transient_yfinance_error(exc) or attempt >= retries:
                raise
            delay = base_delay * (2 ** attempt)
            logger.warning(
                # %.1f so sub-second delays are not logged as "0s".
                "Yahoo Finance transient failure (%s), retrying in %.1fs (attempt %s/%s)",
                exc,
                delay,
                attempt + 1,
                retries,
            )
            time.sleep(delay)
def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
"""Normalize a stock DataFrame for stockstats: parse dates, drop invalid rows, fill price gaps."""
data["Date"] = pd.to_datetime(data["Date"], errors="coerce")
data = data.dropna(subset=["Date"])
price_cols = [c for c in ["Open", "High", "Low", "Close", "Volume"] if c in data.columns]
data[price_cols] = data[price_cols].apply(pd.to_numeric, errors="coerce")
data = data.dropna(subset=["Close"])
data[price_cols] = data[price_cols].ffill().bfill()
return data
def _latest_parsed_date(frame):
    """Return the newest parseable value in ``frame["Date"]``, or None if there is none."""
    if frame is None or "Date" not in frame.columns:
        return None
    latest = pd.to_datetime(frame["Date"], errors="coerce").max()
    return None if pd.isna(latest) else latest


def _fetch_a_share_fallback_ohlcv(symbol: str, start_str: str, end_str: str) -> pd.DataFrame:
    """Fetch OHLCV from Tencent, falling back to Eastmoney if Tencent fails."""
    try:
        return _fetch_tencent_ohlcv(symbol, start_str, end_str)
    except Exception as exc:
        # Previously swallowed silently; keep a trace for diagnosis.
        logger.debug("Tencent OHLCV fallback failed for %s, trying Eastmoney: %s", symbol, exc)
        return _fetch_eastmoney_ohlcv(symbol, start_str, end_str)


def _download_ohlcv_with_fallback(
    symbol: str, start_str: str, end_str: str, min_acceptable_date
) -> pd.DataFrame:
    """Download OHLCV, preferring yfinance and falling back to Tencent/Eastmoney.

    yfinance data is returned as soon as its newest row is on or after
    ``min_acceptable_date``. On a download error or stale data, the
    Tencent/Eastmoney fallback is tried. If that fails too but yfinance did
    return dated rows, those stale rows are returned (with a warning) rather
    than raising. Tickers the A-share-only fallback cannot serve therefore
    still get their data.
    """
    stale_yf_data = None
    try:
        downloaded = yf_retry(lambda: yf.download(
            symbol,
            start=start_str,
            end=end_str,
            multi_level_index=False,
            progress=False,
            auto_adjust=True,
        )).reset_index()
        latest_downloaded = _latest_parsed_date(downloaded)
        if latest_downloaded is not None and latest_downloaded >= min_acceptable_date:
            return downloaded
        if latest_downloaded is not None:
            stale_yf_data = downloaded
        raise ValueError(
            f"yfinance returned stale data for {symbol}: latest={latest_downloaded}"
        )
    except Exception as exc:
        logger.warning(
            "yfinance download failed for %s, falling back to Tencent/Eastmoney OHLCV: %s",
            symbol,
            exc,
        )
    try:
        return _fetch_a_share_fallback_ohlcv(symbol, start_str, end_str)
    except Exception:
        if stale_yf_data is None:
            raise
        logger.warning(
            "Tencent/Eastmoney fallback failed for %s; using stale yfinance data", symbol
        )
        return stale_yf_data


def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
    """Fetch daily OHLCV data with caching, filtered to prevent look-ahead bias.

    Downloads 5 years of data up to today and caches it as CSV under
    ``config["data_cache_dir"]``. The file name embeds the window's start and
    end dates, so a new cache file is written per symbol per calendar day. A
    cached file is reused only if its newest row is recent enough. Rows after
    ``curr_date`` are filtered out so backtests never see future prices.

    Args:
        symbol: Ticker symbol.
        curr_date: Analysis date, ``YYYY-mm-dd``.

    Returns:
        The cleaned OHLCV frame (see ``_clean_dataframe``) restricted to
        ``Date <= curr_date``.
    """
    config = get_config()
    curr_date_dt = pd.to_datetime(curr_date)
    today_date = pd.Timestamp.today()
    # No source can hold bars after today, so judge freshness against the
    # earlier of curr_date and today. Allowing one *business* day of lag
    # (instead of one calendar day) stops weekend/Monday requests from flagging
    # current data as stale.
    reference_dt = min(curr_date_dt, today_date.normalize())
    min_acceptable_date = reference_dt - pd.offsets.BDay(1)

    # Fixed 5-year window ending today; both dates are part of the cache file name.
    start_date = today_date - pd.DateOffset(years=5)
    start_str = start_date.strftime("%Y-%m-%d")
    end_str = today_date.strftime("%Y-%m-%d")

    os.makedirs(config["data_cache_dir"], exist_ok=True)
    data_file = os.path.join(
        config["data_cache_dir"],
        f"{symbol}-YFin-data-{start_str}-{end_str}.csv",
    )

    data = None
    if os.path.exists(data_file):
        cached = pd.read_csv(data_file, on_bad_lines="skip")
        latest_cached = _latest_parsed_date(cached)
        if latest_cached is not None and latest_cached >= min_acceptable_date:
            data = cached

    if data is None:
        data = _download_ohlcv_with_fallback(symbol, start_str, end_str, min_acceptable_date)
        data.to_csv(data_file, index=False)

    data = _clean_dataframe(data)
    # Filter to curr_date to prevent look-ahead bias in backtesting
    return data[data["Date"] <= curr_date_dt]
def filter_financials_by_date(data: pd.DataFrame, curr_date: str) -> pd.DataFrame:
    """Keep only financial-statement columns dated on or before ``curr_date``.

    yfinance financial statements use fiscal period end dates as column
    labels. Columns dated after ``curr_date`` would leak future information
    into a backtest, so they are removed. Labels that do not parse as dates
    fail the comparison and are dropped as well.

    ``data`` is returned unchanged when ``curr_date`` is falsy or ``data``
    is empty.
    """
    if curr_date and not data.empty:
        cutoff = pd.Timestamp(curr_date)
        period_ends = pd.to_datetime(data.columns, errors="coerce")
        return data.loc[:, period_ends <= cutoff]
    return data
class StockstatsUtils:
    """Technical-indicator lookups computed with stockstats on OHLCV data."""

    @staticmethod
    def get_stock_stats(
        symbol: Annotated[str, "ticker symbol for the company"],
        indicator: Annotated[
            str, "quantitative indicators based off of the stock data for the company"
        ],
        curr_date: Annotated[
            str, "curr date for retrieving stock price data, YYYY-mm-dd"
        ],
    ):
        """Return the stockstats value of ``indicator`` for ``symbol`` on ``curr_date``.

        The OHLCV history comes from ``load_ohlcv``, which only contains rows
        up to ``curr_date``. The indicator is computed on that frame and the
        value on the row dated ``curr_date`` is returned.

        Returns:
            The indicator value for that row, or the string
            ``"N/A: Not a trading day (weekend or holiday)"`` when no row is
            dated ``curr_date``.
            NOTE(review): that string is also returned when ``curr_date`` is a
            trading day that is simply missing from the loaded data.
        """
        data = load_ohlcv(symbol, curr_date)
        df = wrap(data)
        # Turn dates into "YYYY-mm-dd" strings so they can be prefix-matched below.
        df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
        curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d")

        # Column access makes stockstats compute the indicator on the full frame.
        # Presumably this must happen before the row filter so look-back windows
        # see the whole history rather than a single row — keep this ordering.
        df[indicator]  # trigger stockstats to calculate the indicator
        matching_rows = df[df["Date"].str.startswith(curr_date_str)]

        if not matching_rows.empty:
            # First matching row's value for the requested indicator.
            indicator_value = matching_rows[indicator].values[0]
            return indicator_value
        else:
            return "N/A: Not a trading day (weekend or holiday)"