TradingAgents/tradingagents/agents/discovery/quantitative_scorer.py

179 lines
6.0 KiB
Python

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING
from tradingagents.agents.discovery.indicators import (
calculate_momentum_score,
calculate_relative_strength_metrics,
calculate_support_resistance_metrics,
calculate_volume_metrics,
)
from tradingagents.agents.discovery.indicators.timeframe import (
calculate_timeframe_signals,
)
from tradingagents.agents.discovery.quantitative_cache import clear_run_cache
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
from tradingagents.config import QuantitativeWeightsConfig, get_settings
if TYPE_CHECKING:
from tradingagents.agents.discovery.models import TrendingStock
logger = logging.getLogger(__name__)
def calculate_unified_score(
    momentum: float,
    volume: float,
    rs: float,
    rr: float,
    weights: "QuantitativeWeightsConfig",
) -> float:
    """Blend the four component scores into one weighted score in [0, 1].

    Args:
        momentum: Momentum component score.
        volume: Volume component score.
        rs: Relative-strength component score.
        rr: Risk/reward component score.
        weights: Object exposing ``momentum_weight``, ``volume_weight``,
            ``relative_strength_weight`` and ``risk_reward_weight``.

    Returns:
        The weighted sum clamped to ``[0.0, 1.0]``. A NaN sum (any NaN
        component or weight) yields ``0.0``.
    """
    import math  # function-scope import: the module does not import math

    score = (
        momentum * weights.momentum_weight
        + volume * weights.volume_weight
        + rs * weights.relative_strength_weight
        + rr * weights.risk_reward_weight
    )
    # min(1.0, nan) evaluates to 1.0, so without this guard an uncomputable
    # score would be reported as the best possible score.
    if math.isnan(score):
        return 0.0
    return max(0.0, min(1.0, score))
def calculate_single_stock_metrics(
    ticker: str,
    curr_date: str,
) -> QuantitativeMetrics | None:
    """Run every indicator calculation for one ticker and bundle the results.

    The four indicator helpers and the timeframe helper are each called once;
    their results are read with ``.get`` so a missing key becomes ``None``
    (or ``0.5`` for the four component scores).

    Args:
        ticker: Symbol passed through to the indicator helpers.
        curr_date: Date string passed through to the indicator helpers.

    Returns:
        A populated ``QuantitativeMetrics``, or ``None`` if any step raised;
        the failure is logged as a warning rather than propagated.
    """
    # Keys copied verbatim from each indicator result into the metrics object.
    momentum_keys = (
        "rsi",
        "macd",
        "macd_signal",
        "macd_histogram",
        "price_vs_sma50",
        "price_vs_sma200",
        "ema10_direction",
    )
    volume_keys = ("volume_ratio", "volume_trend", "dollar_volume")
    strength_keys = (
        "rs_vs_spy_5d",
        "rs_vs_spy_20d",
        "rs_vs_spy_60d",
        "rs_vs_sector",
        "sector_etf",
    )
    level_keys = (
        "support_level",
        "resistance_level",
        "atr",
        "suggested_stop",
        "reward_target",
        "risk_reward_ratio",
    )
    timeframe_keys = (
        "timeframe_alignment",
        "short_term_signal",
        "medium_term_signal",
        "long_term_signal",
        "signal_strength",
    )

    try:
        scoring_weights = get_settings().quantitative_weights

        momentum_data = calculate_momentum_score(ticker, curr_date)
        volume_data = calculate_volume_metrics(ticker, curr_date)
        strength_data = calculate_relative_strength_metrics(ticker, curr_date)
        level_data = calculate_support_resistance_metrics(ticker, curr_date)
        timeframe_data = calculate_timeframe_signals(momentum_data, strength_data)

        # Component scores fall back to 0.5 when an indicator omits them.
        component_scores = {
            "momentum_score": momentum_data.get("momentum_score", 0.5),
            "volume_score": volume_data.get("volume_score", 0.5),
            "relative_strength_score": strength_data.get(
                "relative_strength_score", 0.5
            ),
            "risk_reward_score": level_data.get("risk_reward_score", 0.5),
        }

        blended = calculate_unified_score(
            momentum=component_scores["momentum_score"],
            volume=component_scores["volume_score"],
            rs=component_scores["relative_strength_score"],
            rr=component_scores["risk_reward_score"],
            weights=scoring_weights,
        )

        detail_fields = {}
        for source, keys in (
            (momentum_data, momentum_keys),
            (volume_data, volume_keys),
            (strength_data, strength_keys),
            (level_data, level_keys),
            (timeframe_data, timeframe_keys),
        ):
            for key in keys:
                detail_fields[key] = source.get(key)

        return QuantitativeMetrics(
            **component_scores,
            **detail_fields,
            quantitative_score=blended,
        )
    except Exception as exc:
        # Best-effort: one failing ticker must not abort the whole batch.
        logger.warning(
            "Failed to calculate quantitative metrics for %s: %s", ticker, str(exc)
        )
        return None
def _normalize_score(score: float, max_score: float) -> float:
if max_score == 0:
return 0.0
return min(1.0, score / max_score)
def enhance_with_quantitative_scores(
    stocks: list["TrendingStock"],
    curr_date: str,
    max_stocks: int = 50,
) -> list["TrendingStock"]:
    """Attach quantitative metrics and a conviction score to the top stocks.

    Stocks are ranked by ``score`` (descending). The first ``max_stocks`` are
    scored concurrently via ``calculate_single_stock_metrics``; each receives
    ``quantitative_metrics`` and a ``conviction_score`` that blends its
    normalised ``score`` with the quantitative score using the configured
    weights. When metrics are unavailable only the weighted normalised
    ``score`` is used. Stocks beyond ``max_stocks`` get ``None`` for both
    attributes. The stock objects are mutated in place.

    Args:
        stocks: Candidate stocks; each must expose ``ticker`` and ``score``.
        curr_date: Date string forwarded to the metric calculation.
        max_stocks: Number of top-ranked stocks to score quantitatively.

    Returns:
        The scored stocks ordered by ``conviction_score`` (descending),
        followed by the unscored remainder in ``score`` order. An empty
        input returns an empty list.
    """
    if not stocks:
        return []

    weights = get_settings().quantitative_weights

    sorted_stocks = sorted(stocks, key=lambda s: s.score, reverse=True)
    stocks_to_process = sorted_stocks[:max_stocks]
    remaining_stocks = sorted_stocks[max_stocks:]

    # Start from a clean per-run cache; the ``finally`` below guarantees it is
    # cleared again even if scoring raises part-way through.
    clear_run_cache()
    try:
        results: dict[str, QuantitativeMetrics | None] = {}
        with ThreadPoolExecutor(max_workers=10) as executor:
            future_to_ticker = {
                executor.submit(
                    calculate_single_stock_metrics, stock.ticker, curr_date
                ): stock.ticker
                for stock in stocks_to_process
            }
            for future in as_completed(future_to_ticker):
                ticker = future_to_ticker[future]
                try:
                    results[ticker] = future.result()
                except Exception as e:
                    # Best-effort: a failed ticker is scored on news alone.
                    logger.warning(
                        "Quantitative scoring failed for %s: %s", ticker, e
                    )
                    results[ticker] = None

        # ``or 1.0`` avoids normalising against a maximum of zero.
        max_news_score = (
            max((s.score for s in stocks_to_process), default=1.0) or 1.0
        )

        for stock in stocks_to_process:
            metrics = results.get(stock.ticker)
            stock.quantitative_metrics = metrics
            news_normalized = _normalize_score(stock.score, max_news_score)
            if metrics is not None:
                stock.conviction_score = (
                    weights.news_sentiment_weight * news_normalized
                    + weights.quantitative_weight * metrics.quantitative_score
                )
            else:
                stock.conviction_score = (
                    news_normalized * weights.news_sentiment_weight
                )

        for stock in remaining_stocks:
            stock.quantitative_metrics = None
            stock.conviction_score = None

        enhanced_stocks = sorted(
            stocks_to_process,
            key=lambda s: (
                s.conviction_score if s.conviction_score is not None else 0.0
            ),
            reverse=True,
        )
    finally:
        clear_run_cache()

    return enhanced_stocks + remaining_stocks