"""ML signal scanner — surfaces high P(WIN) setups from a ticker universe.
|
|
|
|
Universe is loaded from a text file (one ticker per line, # comments allowed).
|
|
Default: data/tickers.txt. Override via config: discovery.scanners.ml_signal.ticker_file
|
|
"""
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import pandas as pd
|
|
|
|
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
|
|
from tradingagents.dataflows.discovery.utils import Priority
|
|
from tradingagents.utils.logger import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
# Default ticker file path (relative to project root)
|
|
DEFAULT_TICKER_FILE = "data/tickers.txt"
|
|
|
|
|
|
def _load_tickers_from_file(path: str) -> List[str]:
    """Load ticker symbols from a text file (one per line, # comments allowed).

    Blank lines and lines starting with ``#`` are skipped; symbols are
    upper-cased. Returns an empty list on any read failure so callers can
    treat "no universe" uniformly.

    Args:
        path: Path to the ticker file.

    Returns:
        List of upper-cased ticker symbols (possibly empty).
    """
    try:
        # Explicit encoding: ticker files are ASCII/UTF-8; relying on the
        # platform default could mis-decode on some systems.
        with open(path, encoding="utf-8") as f:
            # Strip once per line instead of three times.
            stripped = (line.strip() for line in f)
            tickers = [s.upper() for s in stripped if s and not s.startswith("#")]
        if tickers:
            logger.info(f"ML scanner: loaded {len(tickers)} tickers from {path}")
            return tickers
    except FileNotFoundError:
        logger.warning(f"Ticker file not found: {path}")
    except Exception as e:
        logger.warning(f"Failed to load ticker file {path}: {e}")
    # Reached when the file was empty, had only comments, or failed to read.
    return []
|
|
|
|
|
|
class MLSignalScanner(BaseScanner):
    """Scan a ticker universe for high ML win-probability setups.

    Loads the trained LightGBM/TabPFN model, fetches recent OHLCV data
    for a universe of tickers, computes technical features, and returns
    candidates whose predicted P(WIN) exceeds a configurable threshold.

    Optimized for large universes (500+ tickers):
    - Single batch yfinance download (1 HTTP request)
    - Parallel feature computation via ThreadPoolExecutor
    - Market cap skipped by default (1 NaN feature out of 30)
    """

    name = "ml_signal"
    pipeline = "momentum"
    strategy = "ml_signal"

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Tunables, all read from this scanner's own config section.
        self.min_win_prob = self.scanner_config.get("min_win_prob", 0.50)
        self.lookback_period = self.scanner_config.get("lookback_period", "1y")
        self.max_workers = self.scanner_config.get("max_workers", 8)
        self.fetch_market_cap = self.scanner_config.get("fetch_market_cap", False)

        # Market-cap cache, created eagerly here rather than lazily in
        # _get_market_cap: that method runs on ThreadPoolExecutor worker
        # threads, and a hasattr-check-then-create sequence there could race
        # and drop cached entries.
        self._market_cap_cache: Dict[str, Optional[float]] = {}

        # Load universe: config list > config file > default tickers file
        if "ticker_universe" in self.scanner_config:
            self.universe = self.scanner_config["ticker_universe"]
        else:
            ticker_file = self.scanner_config.get(
                "ticker_file",
                config.get("tickers_file", DEFAULT_TICKER_FILE),
            )
            self.universe = _load_tickers_from_file(ticker_file)
            if not self.universe:
                logger.warning(f"No tickers loaded from {ticker_file} — scanner will be empty")

    def scan(self, state: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the full scan pipeline and return candidate dicts.

        Steps: load model -> batch-fetch OHLCV -> predict per ticker ->
        sort by P(WIN) descending -> truncate to ``self.limit``.
        Returns [] when disabled or when any prerequisite is unavailable.
        """
        if not self.is_enabled():
            return []

        logger.info(
            f"Running ML signal scanner on {len(self.universe)} tickers "
            f"(min P(WIN) = {self.min_win_prob:.0%})..."
        )

        # 1. Load ML model
        predictor = self._load_predictor()
        if predictor is None:
            logger.warning("No ML model available — skipping ml_signal scanner")
            return []

        # 2. Batch-fetch OHLCV data (single HTTP request)
        ohlcv_by_ticker = self._fetch_universe_ohlcv()
        if not ohlcv_by_ticker:
            logger.warning("No OHLCV data fetched — skipping ml_signal scanner")
            return []

        # 3. Compute features and predict in parallel
        candidates = self._predict_universe(predictor, ohlcv_by_ticker)

        # 4. Sort by P(WIN) descending and apply limit
        candidates.sort(key=lambda c: c.get("ml_win_prob", 0), reverse=True)
        candidates = candidates[: self.limit]

        logger.info(
            f"ML signal scanner: {len(candidates)} candidates above "
            f"{self.min_win_prob:.0%} threshold (from {len(ohlcv_by_ticker)} tickers)"
        )

        # Log a fixed-width table of the surviving candidates for quick triage.
        if candidates:
            header = (
                f"{'Ticker':<8} {'P(WIN)':>8} {'P(LOSS)':>9} {'Prediction':>12} {'Priority':>10}"
            )
            separator = "-" * len(header)
            lines = ["\n  ML Signal Scanner Results:", f"  {header}", f"  {separator}"]
            for c in candidates:
                lines.append(
                    f"  {c['ticker']:<8} {c.get('ml_win_prob', 0):>7.1%} "
                    f"{c.get('ml_loss_prob', 0):>9.1%} "
                    f"{c.get('ml_prediction', 'N/A'):>12} "
                    f"{c.get('priority', 'N/A'):>10}"
                )
            lines.append(f"  {separator}")
            logger.info("\n".join(lines))

        return candidates

    def _load_predictor(self):
        """Load the trained ML model, or return None if unavailable."""
        try:
            # Imported lazily so the scanner degrades gracefully when the ML
            # stack isn't installed or the model file is missing.
            from tradingagents.ml.predictor import MLPredictor

            return MLPredictor.load()
        except Exception as e:
            logger.warning(f"Failed to load ML predictor: {e}")
            return None

    def _fetch_universe_ohlcv(self) -> Dict[str, pd.DataFrame]:
        """Batch-fetch OHLCV data for the entire ticker universe.

        Uses yfinance batch download — a single HTTP request regardless of
        universe size. This is the key optimization for large universes.

        Returns:
            Mapping of ticker -> OHLCV DataFrame (with ``Date`` as a column).
            Empty dict on any failure.
        """
        # Guard: an empty universe would send a request for "" — skip it.
        if not self.universe:
            return {}

        try:
            from tradingagents.dataflows.y_finance import download_history

            logger.info(
                f"Batch-downloading {len(self.universe)} tickers ({self.lookback_period})..."
            )

            # yfinance batch download — single HTTP request for all tickers
            raw = download_history(
                " ".join(self.universe),
                period=self.lookback_period,
                auto_adjust=True,
                progress=False,
            )

            if raw.empty:
                return {}

            result: Dict[str, pd.DataFrame] = {}
            if isinstance(raw.columns, pd.MultiIndex):
                # Multi-ticker download: columns are (Price, Ticker). Split
                # the frame into one single-level DataFrame per ticker.
                for ticker in raw.columns.get_level_values(1).unique():
                    try:
                        ticker_df = raw.xs(ticker, level=1, axis=1).copy()
                        ticker_df = ticker_df.reset_index()
                        if len(ticker_df) > 0:
                            result[ticker] = ticker_df
                    except (KeyError, ValueError):
                        # Ticker had no usable slice — skip it.
                        continue
            else:
                # Single-ticker fallback: flat columns, no ticker level.
                raw = raw.reset_index()
                if len(self.universe) == 1:
                    result[self.universe[0]] = raw

            logger.info(f"Fetched OHLCV for {len(result)} tickers")
            return result

        except Exception as e:
            logger.warning(f"OHLCV batch fetch failed: {e}")
            return {}

    def _predict_universe(
        self, predictor, ohlcv_by_ticker: Dict[str, pd.DataFrame]
    ) -> List[Dict[str, Any]]:
        """Predict P(WIN) for all tickers using parallel feature computation.

        Serial for small universes (<=10 tickers or max_workers <= 1);
        otherwise fans out over a ThreadPoolExecutor.
        """
        candidates: List[Dict[str, Any]] = []

        if self.max_workers <= 1 or len(ohlcv_by_ticker) <= 10:
            # Serial execution: thread-pool overhead isn't worth it here.
            for ticker, ohlcv in ohlcv_by_ticker.items():
                result = self._predict_ticker(predictor, ticker, ohlcv)
                if result is not None:
                    candidates.append(result)
        else:
            # Parallel feature computation for large universes.
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = {
                    executor.submit(self._predict_ticker, predictor, ticker, ohlcv): ticker
                    for ticker, ohlcv in ohlcv_by_ticker.items()
                }
                for future in as_completed(futures):
                    try:
                        # as_completed yields finished futures, so the timeout
                        # is a belt-and-braces guard, not an expected wait.
                        result = future.result(timeout=10)
                        if result is not None:
                            candidates.append(result)
                    except Exception as e:
                        ticker = futures[future]
                        logger.debug(f"{ticker}: prediction timed out or failed — {e}")

        return candidates

    def _predict_ticker(
        self, predictor, ticker: str, ohlcv: pd.DataFrame
    ) -> Optional[Dict[str, Any]]:
        """Compute features and predict P(WIN) for a single ticker.

        Returns a candidate dict when P(WIN) clears ``self.min_win_prob``,
        otherwise None (also None on insufficient history or any error).
        """
        try:
            from tradingagents.ml.feature_engineering import (
                MIN_HISTORY_ROWS,
                compute_features_single,
            )

            # Not enough history to compute the feature window.
            if len(ohlcv) < MIN_HISTORY_ROWS:
                return None

            # Market cap: skip by default for speed (1 NaN out of 30 features)
            market_cap = self._get_market_cap(ticker) if self.fetch_market_cap else None

            # Compute features as of the most recent bar in the data.
            latest_date = pd.to_datetime(ohlcv["Date"]).max().strftime("%Y-%m-%d")
            features = compute_features_single(ohlcv, latest_date, market_cap=market_cap)
            if features is None:
                return None

            # Run ML prediction
            prediction = predictor.predict(features)
            if prediction is None:
                return None

            win_prob = prediction.get("win_prob", 0)
            loss_prob = prediction.get("loss_prob", 0)

            # Below threshold — not a candidate.
            if win_prob < self.min_win_prob:
                return None

            # Map P(WIN) bands to discovery priority levels.
            if win_prob >= 0.65:
                priority = Priority.CRITICAL.value
            elif win_prob >= 0.55:
                priority = Priority.HIGH.value
            else:
                priority = Priority.MEDIUM.value

            return {
                "ticker": ticker,
                "source": self.name,
                "context": (
                    f"ML model: {win_prob:.0%} win probability, "
                    f"{loss_prob:.0%} loss probability "
                    f"({prediction.get('prediction', 'N/A')})"
                ),
                "priority": priority,
                "strategy": self.strategy,
                "ml_win_prob": win_prob,
                "ml_loss_prob": loss_prob,
                "ml_prediction": prediction.get("prediction", "N/A"),
            }

        except Exception as e:
            # Per-ticker failures are expected (bad data, delisted symbols);
            # log at debug and move on rather than aborting the scan.
            logger.debug(f"{ticker}: ML prediction failed — {e}")
            return None

    def _get_market_cap(self, ticker: str) -> Optional[float]:
        """Get market cap (best-effort, cached in memory for the scan).

        The cache is created in __init__, so this method is safe to call
        from worker threads without a lazy-init race.
        """
        if ticker in self._market_cap_cache:
            return self._market_cap_cache[ticker]

        try:
            from tradingagents.dataflows.y_finance import get_ticker_info

            info = get_ticker_info(ticker)
            cap = info.get("marketCap")
            self._market_cap_cache[ticker] = cap
            return cap
        except Exception:
            # Cache the failure too so we don't re-fetch within this scan.
            self._market_cap_cache[ticker] = None
            return None
|
|
|
|
|
|
# Register at import time so the discovery pipeline can enumerate this scanner.
SCANNER_REGISTRY.register(MLSignalScanner)
|