# TradingAgents/tradingagents/ml/feature_engineering.py

"""Shared feature extraction for ML model — used by both training and inference.
All 20 features are computed locally from OHLCV data via stockstats + pandas.
Zero API calls required for indicator computation.
"""
from __future__ import annotations
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
from stockstats import wrap
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
# Canonical feature list — order matters for model consistency.
# This exact column order is the model's input layout at both training and
# inference time; never reorder or remove entries, only append.
FEATURE_COLUMNS: List[str] = [
    # Base indicators (20)
    "rsi_14",
    "macd",
    "macd_signal",
    "macd_hist",
    "atr_pct",
    "bb_width_pct",
    "bb_position",
    "adx",
    "mfi",
    "stoch_k",
    "volume_ratio_5d",
    "volume_ratio_20d",
    "return_1d",
    "return_5d",
    "return_20d",
    "sma50_distance",
    "sma200_distance",
    "high_low_range",
    "gap_pct",
    "log_market_cap",
    # Interaction & derived features (10)
    "momentum_x_compression",  # strong trend + tight bands = breakout signal
    "rsi_momentum",  # RSI rate of change (acceleration)
    "volume_price_confirm",  # volume surge + positive return = confirmed move
    "trend_alignment",  # SMA50 and SMA200 agree on direction
    "volatility_regime",  # ATR percentile rank (0-1) within own history
    "mean_reversion_signal",  # oversold RSI + below lower BB
    "breakout_signal",  # above upper BB + high volume
    "macd_strength",  # MACD histogram normalized by ATR
    "return_volatility_ratio",  # Sharpe-like: return_5d / atr_pct
    "trend_momentum_score",  # combined trend + momentum z-score
]

# Minimum rows of OHLCV history needed before features are valid.
# (The 200-day SMA needs 200 rows of prior data; 210 leaves a small buffer.)
MIN_HISTORY_ROWS = 210
def compute_features_bulk(ohlcv: pd.DataFrame, market_cap: Optional[float] = None) -> pd.DataFrame:
    """Compute all 30 ML features (20 base + 10 derived) for every OHLCV row.

    Args:
        ohlcv: DataFrame with columns: Date, Open, High, Low, Close, Volume.
            Must be sorted by Date ascending.
        market_cap: Market capitalization in USD. If None, log_market_cap = NaN.

    Returns:
        DataFrame indexed by Date with one column per feature, in
        FEATURE_COLUMNS order. Rows with insufficient history
        (first ~210) will have NaN values.
    """
    # Guard: the 200-day SMA (plus buffer) needs MIN_HISTORY_ROWS rows.
    if ohlcv.empty or len(ohlcv) < MIN_HISTORY_ROWS:
        return pd.DataFrame(columns=FEATURE_COLUMNS)
    df = ohlcv.copy()
    # Ensure Date column is available and set as index
    if "Date" in df.columns:
        df["Date"] = pd.to_datetime(df["Date"])
        df = df.set_index("Date").sort_index()
    elif not isinstance(df.index, pd.DatetimeIndex):
        df.index = pd.to_datetime(df.index)
        df = df.sort_index()
    # Normalize column names (yfinance sometimes returns Title Case)
    col_map = {}
    for col in df.columns:
        lower = col.lower()
        if lower == "open":
            col_map[col] = "Open"
        elif lower == "high":
            col_map[col] = "High"
        elif lower == "low":
            col_map[col] = "Low"
        elif lower in ("close", "adj close"):
            col_map[col] = "Close"
        elif lower == "volume":
            col_map[col] = "Volume"
    df = df.rename(columns=col_map)
    # Need these columns — bail out with an empty frame if any is absent.
    for required in ("Open", "High", "Low", "Close", "Volume"):
        if required not in df.columns:
            logger.warning(f"Missing column {required} in OHLCV data")
            return pd.DataFrame(columns=FEATURE_COLUMNS)
    close = df["Close"]
    volume = df["Volume"]
    # --- Stockstats indicators ---
    # Defensive copy: presumably wrap() mutates its input frame — TODO confirm.
    ss = wrap(df.copy())
    features = pd.DataFrame(index=df.index)
    # 1. RSI (14-period)
    features["rsi_14"] = ss["rsi_14"]
    # 2-4. MACD (12, 26, 9)
    features["macd"] = ss["macd"]
    features["macd_signal"] = ss["macds"]
    features["macd_hist"] = ss["macdh"]
    # 5. ATR as percentage of price (makes volatility comparable across tickers)
    atr = ss["atr_14"]
    features["atr_pct"] = (atr / close) * 100
    # 6. Bollinger Band width as percentage of the middle band
    bb_upper = ss["boll_ub"]
    bb_lower = ss["boll_lb"]
    bb_middle = ss["boll"]
    features["bb_width_pct"] = ((bb_upper - bb_lower) / bb_middle) * 100
    # 7. Position within Bollinger Bands (0 = lower band, 1 = upper band);
    # a degenerate zero-width band falls back to the midpoint 0.5.
    bb_range = bb_upper - bb_lower
    features["bb_position"] = np.where(bb_range > 0, (close - bb_lower) / bb_range, 0.5)
    # 8. ADX (trend strength)
    # NOTE(review): stockstats "dx_14" is the raw Directional Index, not the
    # smoothed ADX ("adx"). Kept as-is for trained-model consistency —
    # confirm this substitution is intentional.
    features["adx"] = ss["dx_14"]
    # 9. Money Flow Index
    features["mfi"] = ss["mfi_14"]
    # 10. Stochastic %K
    features["stoch_k"] = ss["kdjk"]
    # --- Pandas-computed features ---
    # 11-12. Volume ratios; zero moving averages become NaN to avoid div-by-zero.
    vol_ma_5 = volume.rolling(5).mean()
    vol_ma_20 = volume.rolling(20).mean()
    features["volume_ratio_5d"] = volume / vol_ma_5.replace(0, np.nan)
    features["volume_ratio_20d"] = volume / vol_ma_20.replace(0, np.nan)
    # 13-15. Historical returns (looking backward — no data leakage)
    features["return_1d"] = close.pct_change(1, fill_method=None) * 100
    features["return_5d"] = close.pct_change(5, fill_method=None) * 100
    features["return_20d"] = close.pct_change(20, fill_method=None) * 100
    # 16-17. Distance from moving averages, as percent of the SMA
    sma_50 = close.rolling(50).mean()
    sma_200 = close.rolling(200).mean()
    features["sma50_distance"] = ((close - sma_50) / sma_50) * 100
    features["sma200_distance"] = ((close - sma_200) / sma_200) * 100
    # 18. High-Low range as percentage of close (intraday volatility proxy)
    features["high_low_range"] = ((df["High"] - df["Low"]) / close) * 100
    # 19. Gap percentage (open vs previous close)
    prev_close = close.shift(1)
    features["gap_pct"] = ((df["Open"] - prev_close) / prev_close) * 100
    # 20. Log market cap (static per stock; NaN when cap is unknown/non-positive)
    if market_cap and market_cap > 0:
        features["log_market_cap"] = np.log10(market_cap)
    else:
        features["log_market_cap"] = np.nan
    # --- Interaction & derived features (10) ---
    # 21. Momentum × Compression: strong trend direction + tight Bollinger = breakout setup
    # High absolute MACD histogram + low BB width = coiled spring.
    features["momentum_x_compression"] = features["macd_hist"].abs() / features[
        "bb_width_pct"
    ].replace(0, np.nan)
    # 22. RSI momentum: 5-day rate of change of RSI (acceleration of momentum)
    features["rsi_momentum"] = features["rsi_14"] - features["rsi_14"].shift(5)
    # 23. Volume-price confirmation: volume surge accompanied by price move
    features["volume_price_confirm"] = features["volume_ratio_5d"] * features["return_1d"]
    # 24. Trend alignment: both SMAs agree (1 = aligned bullish, -1 = aligned bearish)
    features["trend_alignment"] = np.sign(features["sma50_distance"]) * np.sign(
        features["sma200_distance"]
    )
    # 25. Volatility regime: ATR min-max rank within rolling 60-day window (0-1);
    # a flat window (max == min) maps to the neutral midpoint 0.5.
    atr_pct_series = features["atr_pct"]
    features["volatility_regime"] = atr_pct_series.rolling(60).apply(
        lambda x: (x.iloc[-1] - x.min()) / (x.max() - x.min()) if x.max() != x.min() else 0.5,
        raw=False,
    )
    # 26. Mean reversion signal: oversold RSI + price below lower Bollinger
    features["mean_reversion_signal"] = (
        (100 - features["rsi_14"]) / 100  # inversed RSI (higher = more oversold)
    ) * (
        1 - features["bb_position"].clip(0, 1)
    )  # below lower band amplifies signal
    # 27. Breakout signal: above upper BB + high volume ratio
    features["breakout_signal"] = features["bb_position"].clip(0, 2) * features["volume_ratio_20d"]
    # 28. MACD strength: histogram normalized by volatility (ATR%)
    features["macd_strength"] = features["macd_hist"] / features["atr_pct"].replace(0, np.nan)
    # 29. Return/Volatility ratio: Sharpe-like metric
    features["return_volatility_ratio"] = features["return_5d"] / features["atr_pct"].replace(
        0, np.nan
    )
    # 30. Trend-momentum composite score (fixed 0.4/0.3/0.3 weighting)
    features["trend_momentum_score"] = (
        features["sma50_distance"] * 0.4
        + features["rsi_14"].sub(50) * 0.3  # RSI centered at 50
        + features["macd_hist"] * 0.3
    )
    # Reindex to the canonical column order expected by the model.
    return features[FEATURE_COLUMNS]
def compute_features_single(
    ohlcv: pd.DataFrame,
    date: str,
    market_cap: Optional[float] = None,
) -> Optional[Dict[str, float]]:
    """Return the feature vector for a single date (live-inference path).

    Args:
        ohlcv: Full OHLCV DataFrame (needs ~210 rows of history before `date`).
        date: Target date string (YYYY-MM-DD).
        market_cap: Market cap in USD.

    Returns:
        Dict mapping feature name → value, or None when there is not enough
        history, no row exists on/before `date`, or the row is entirely NaN.
    """
    all_features = compute_features_bulk(ohlcv, market_cap=market_cap)
    if all_features.empty:
        return None
    cutoff = pd.Timestamp(date)
    # Snap to the most recent trading day on or before the target date.
    eligible = all_features.index[all_features.index <= cutoff]
    if len(eligible) == 0:
        return None
    latest = all_features.loc[eligible[-1]]
    if latest.isna().all():
        return None
    return latest.to_dict()
def compute_features_from_enriched_candidate(cand: Dict) -> Optional[Dict[str, float]]:
    """Extract ML features from an already-enriched discovery candidate.

    During live inference, the enrichment pipeline has already computed
    some of the values we need. This avoids redundant computation.

    NOTE(review): only ~6 of the 30 FEATURE_COLUMNS can be sourced from an
    enriched candidate, so the ">5 missing" check below always triggers the
    None fallback as written — confirm whether this fast path is meant to
    be exercised at all.

    Args:
        cand: Enriched candidate dict from filter.py.

    Returns:
        Dict of feature values (unavailable ones filled with NaN), or None
        if too many features are missing and the caller should fall back to
        compute_features_single().
    """
    features: Dict[str, float] = {}
    # Features already available on enriched candidates
    features["rsi_14"] = cand.get("rsi_value", np.nan)
    features["atr_pct"] = cand.get("atr_pct", np.nan)
    features["bb_width_pct"] = cand.get("bb_width_pct", np.nan)
    features["volume_ratio_20d"] = cand.get("volume_ratio", np.nan)
    # Market cap (candidate stores it in billions of USD)
    market_cap_bil = cand.get("market_cap_bil")
    if market_cap_bil and market_cap_bil > 0:
        features["log_market_cap"] = np.log10(market_cap_bil * 1e9)
    else:
        features["log_market_cap"] = np.nan
    # Intraday return as proxy for return_1d
    features["return_1d"] = cand.get("intraday_change_pct", np.nan)
    # NOTE(review): the previous version read cand["short_interest_pct"] and
    # then re-assigned log_market_cap to itself — a no-op. That dead branch
    # is removed; short interest is currently not wired into any feature.

    def _is_missing(value) -> bool:
        # Absent keys, None, and NaN all count as missing. The explicit None
        # check matters: np.isnan(None) would raise TypeError.
        return value is None or (isinstance(value, float) and np.isnan(value))

    # For features not directly available on enriched candidates,
    # we need to fetch OHLCV and compute. This is the "full path".
    # Return None to signal the caller should use compute_features_single() instead.
    missing = [f for f in FEATURE_COLUMNS if _is_missing(features.get(f))]
    if len(missing) > 5:
        # Too many missing — need full OHLCV computation
        return None
    # Fill remaining with NaN (TabPFN handles missing values)
    for col in FEATURE_COLUMNS:
        if col not in features:
            features[col] = np.nan
    return features
def apply_triple_barrier_labels(
    close_prices: pd.Series,
    profit_target: float = 0.05,
    stop_loss: float = 0.03,
    max_holding_days: int = 7,
) -> pd.Series:
    """Label each day with the triple-barrier method.

    Scanning forward up to `max_holding_days` trading days from each entry:
        +1 (WIN): price touches +profit_target first
        -1 (LOSS): price touches -stop_loss first
         0 (TIMEOUT): neither barrier is touched inside the window

    Args:
        close_prices: Series of daily close prices, indexed by date.
        profit_target: Upside barrier as a fraction (0.05 = 5%).
        stop_loss: Downside barrier as a fraction (0.03 = 3%).
        max_holding_days: Maximum forward-looking trading days.

    Returns:
        Series of labels (+1, -1, 0) aligned with close_prices index.
        The final `max_holding_days` rows are NaN (no forward window).
    """
    px = close_prices.values
    total = len(px)
    out = np.full(total, np.nan)
    for start in range(total - max_holding_days):
        entry = px[start]
        take_profit = entry * (1 + profit_target)
        stop_level = entry * (1 - stop_loss)
        # Walk the forward window; the first barrier touched decides the label.
        verdict = 0  # timeout unless a barrier fires below
        for future_px in px[start + 1 : start + max_holding_days + 1]:
            if future_px >= take_profit:
                verdict = 1  # profit target reached
                break
            if future_px <= stop_level:
                verdict = -1  # stop loss reached
                break
        out[start] = verdict
    return pd.Series(out, index=close_prices.index, name="label")