feat: add quantitative scoring with multi-timeframe analysis and CLI enhancements

Add quantitative scoring pipeline for discovery with technical indicator analysis:
- Momentum, volume, relative strength, and risk/reward scoring
- Support/resistance level detection
- Gap analysis for price momentum signals
- Configurable caching to reduce API calls

Implement multi-timeframe signal analysis:
- Short-term (5/20 day), medium-term (20/50 day), and long-term (50/200 day) signals
- Timeframe alignment detection (aligned_bullish, aligned_bearish, mixed, neutral)
- Signal strength calculation based on indicator agreement

Enhance CLI discovery display:
- Color-coded conviction scores (green/yellow/red thresholds)
- Signal column showing timeframe alignment status
- News mentions count column

Update tests to support new quantitative filtering configuration.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 17:04:18 -05:00
parent fb1a66f5a6
commit 1ba647ae4a
31 changed files with 4441 additions and 24 deletions

View File

@ -153,6 +153,35 @@ def select_event_filter() -> list[EventCategory] | None:
return choices
def _get_conviction_display(stock: TrendingStock) -> tuple[str, str]:
if stock.conviction_score is None:
return "-", "dim"
score = stock.conviction_score
if score >= 0.7:
return f"{score:.2f}", "bold green"
elif score >= 0.5:
return f"{score:.2f}", "yellow"
else:
return f"{score:.2f}", "red"
def _get_signal_display(stock: TrendingStock) -> str:
if stock.quantitative_metrics is None:
return "-"
alignment = stock.quantitative_metrics.timeframe_alignment
if alignment == "aligned_bullish":
return "[bold green]+++[/bold green]"
elif alignment == "aligned_bearish":
return "[bold red]---[/bold red]"
elif alignment == "mixed":
strength = stock.quantitative_metrics.signal_strength or 0.5
if strength > 0.5:
return "[yellow]++[/yellow]"
else:
return "[yellow]--[/yellow]"
return "[dim]~[/dim]"
def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Table:
table = Table(
show_header=True,
@ -164,11 +193,12 @@ def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Tabl
)
table.add_column("Rank", style="cyan", justify="center", width=6)
table.add_column("Ticker", style="bold yellow", justify="center", width=10)
table.add_column("Company", style="white", justify="left", width=25)
table.add_column("Score", style="green", justify="right", width=10)
table.add_column("Mentions", style="blue", justify="center", width=10)
table.add_column("Event Type", style="magenta", justify="center", width=18)
table.add_column("Ticker", style="bold yellow", justify="center", width=8)
table.add_column("Company", style="white", justify="left", width=20)
table.add_column("Conv.", justify="right", width=6)
table.add_column("Signal", justify="center", width=7)
table.add_column("News", style="blue", justify="right", width=6)
table.add_column("Event Type", style="magenta", justify="center", width=15)
for rank, stock in enumerate(trending_stocks, 1):
if rank <= 3:
@ -178,20 +208,54 @@ def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Tabl
rank_display = str(rank)
ticker_display = stock.ticker
conviction_text, conviction_style = _get_conviction_display(stock)
signal_display = _get_signal_display(stock)
table.add_row(
rank_display,
ticker_display,
stock.company_name[:25]
if len(stock.company_name) > 25
stock.company_name[:20]
if len(stock.company_name) > 20
else stock.company_name,
f"{stock.score:.2f}",
str(stock.mention_count),
stock.event_type.value.replace("_", " ").title(),
f"[{conviction_style}]{conviction_text}[/{conviction_style}]",
signal_display,
f"{stock.score:.1f}",
stock.event_type.value.replace("_", " ").title()[:15],
)
return table
def _format_timeframe_signals(stock: TrendingStock) -> str:
if stock.quantitative_metrics is None:
return "[dim]No quantitative data available[/dim]"
qm = stock.quantitative_metrics
short_color = (
"green"
if qm.short_term_signal == "bullish"
else "red"
if qm.short_term_signal == "bearish"
else "yellow"
)
med_color = (
"green"
if qm.medium_term_signal == "bullish"
else "red"
if qm.medium_term_signal == "bearish"
else "yellow"
)
long_color = (
"green"
if qm.long_term_signal == "bullish"
else "red"
if qm.long_term_signal == "bearish"
else "yellow"
)
return f"[{short_color}]Short: {(qm.short_term_signal or 'N/A').upper()}[/{short_color}] | [{med_color}]Med: {(qm.medium_term_signal or 'N/A').upper()}[/{med_color}] | [{long_color}]Long: {(qm.long_term_signal or 'N/A').upper()}[/{long_color}]"
def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel:
sentiment_label = (
"positive"
@ -208,14 +272,64 @@ def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel:
else "yellow"
)
conviction_text = (
f"{stock.conviction_score:.2f}" if stock.conviction_score is not None else "N/A"
)
conviction_color = (
"green"
if stock.conviction_score and stock.conviction_score >= 0.7
else "yellow"
if stock.conviction_score and stock.conviction_score >= 0.5
else "red"
)
content = f"""[bold]Rank #{rank}: {stock.ticker} - {stock.company_name}[/bold]
[cyan]Score:[/cyan] {stock.score:.2f}
[cyan]Conviction Score:[/cyan] [{conviction_color}]{conviction_text}[/{conviction_color}]
[cyan]News Score:[/cyan] {stock.score:.2f}
[cyan]Sentiment:[/cyan] [{sentiment_color}]{stock.sentiment:.2f} ({sentiment_label})[/{sentiment_color}]
[cyan]Sector:[/cyan] {stock.sector.value.replace("_", " ").title()}
[cyan]Event Type:[/cyan] {stock.event_type.value.replace("_", " ").title()}
[cyan]Mentions:[/cyan] {stock.mention_count}
[bold]Timeframe Signals:[/bold]
{_format_timeframe_signals(stock)}"""
if stock.quantitative_metrics is not None:
qm = stock.quantitative_metrics
alignment_color = (
"green"
if qm.timeframe_alignment == "aligned_bullish"
else "red"
if qm.timeframe_alignment == "aligned_bearish"
else "yellow"
)
content += f"""
[bold]Quantitative Metrics:[/bold]
[cyan]Timeframe Alignment:[/cyan] [{alignment_color}]{(qm.timeframe_alignment or 'N/A').replace('_', ' ').upper()}[/{alignment_color}]
[cyan]Momentum Score:[/cyan] {qm.momentum_score:.2f} [cyan]Volume Score:[/cyan] {qm.volume_score:.2f}
[cyan]Relative Strength:[/cyan] {qm.relative_strength_score:.2f} [cyan]Risk/Reward:[/cyan] {qm.risk_reward_score:.2f}"""
if qm.rsi is not None:
rsi_color = "green" if qm.rsi < 35 else "red" if qm.rsi > 65 else "yellow"
content += f"\n[cyan]RSI:[/cyan] [{rsi_color}]{qm.rsi:.1f}[/{rsi_color}]"
if qm.support_level is not None and qm.resistance_level is not None:
content += f" [cyan]Support:[/cyan] ${qm.support_level:.2f} [cyan]Resistance:[/cyan] ${qm.resistance_level:.2f}"
if qm.risk_reward_ratio is not None:
rr_color = (
"green"
if qm.risk_reward_ratio >= 2.0
else "yellow"
if qm.risk_reward_ratio >= 1.0
else "red"
)
content += f"\n[cyan]Risk/Reward Ratio:[/cyan] [{rr_color}]{qm.risk_reward_ratio:.2f}:1[/{rr_color}]"
content += f"""
[bold]News Summary:[/bold]
{stock.news_summary}

View File

@ -67,7 +67,9 @@ class TestDiscoverTrendingReturnsDiscoveryResult:
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
}
graph.db_enabled = False
result = graph.discover_trending()
@ -118,7 +120,9 @@ class TestSectorFilterParameter:
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
}
graph.db_enabled = False
request = DiscoveryRequest(
lookback_period="24h",
@ -162,7 +166,9 @@ class TestEventFilterParameter:
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
}
graph.db_enabled = False
request = DiscoveryRequest(
lookback_period="24h",

View File

@ -124,8 +124,9 @@ class TestResultsTableDisplay:
"Rank",
"Ticker",
"Company",
"Score",
"Mentions",
"Conv.",
"Signal",
"News",
"Event Type",
]
for expected in expected_columns:

View File

@ -89,7 +89,9 @@ class TestEndToEndDiscoveryFlow:
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
}
graph.db_enabled = False
request = DiscoveryRequest(lookback_period="24h")
result = graph.discover_trending(request)
@ -259,7 +261,9 @@ class TestNoTrendingStocksFound:
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
}
graph.db_enabled = False
result = graph.discover_trending()
@ -274,7 +278,16 @@ class TestAllStocksFilteredOutBySectorFilter:
def test_all_stocks_filtered_out_by_sector_filter(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_bulk_news.return_value = [
NewsArticle(
title="Test article",
source="Test",
url="https://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
@ -311,7 +324,9 @@ class TestAllStocksFilteredOutBySectorFilter:
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
}
graph.db_enabled = False
request = DiscoveryRequest(
lookback_period="24h",
@ -330,7 +345,16 @@ class TestAllStocksFilteredOutByEventFilter:
def test_all_stocks_filtered_out_by_event_filter(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_bulk_news.return_value = [
NewsArticle(
title="Test article",
source="Test",
url="https://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
@ -356,7 +380,9 @@ class TestAllStocksFilteredOutByEventFilter:
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
}
graph.db_enabled = False
request = DiscoveryRequest(
lookback_period="24h",
@ -375,7 +401,16 @@ class TestMultipleSectorsAndEventsFiltering:
def test_combined_sector_and_event_filtering(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_bulk_news.return_value = [
NewsArticle(
title="Test article",
source="Test",
url="https://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
@ -423,7 +458,9 @@ class TestMultipleSectorsAndEventsFiltering:
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
}
graph.db_enabled = False
request = DiscoveryRequest(
lookback_period="24h",

View File

@ -0,0 +1,197 @@
from unittest.mock import MagicMock, patch
import pytest
class TestCalculateRsiScore:
    """RSI-to-score mapping: lower RSI (oversold) should score more bullish."""

    def test_rsi_oversold_bullish_score(self):
        """Oversold readings (RSI <= 30) land in the strongly bullish band."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_rsi_score,
        )

        assert 0.8 <= calculate_rsi_score(25.0) <= 1.0
        assert 0.7 <= calculate_rsi_score(30.0) <= 0.9

    def test_rsi_neutral_score(self):
        """Mid-range readings score near the neutral middle."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_rsi_score,
        )

        assert 0.5 <= calculate_rsi_score(50.0) <= 0.7
        assert 0.5 <= calculate_rsi_score(40.0) <= 0.7

    def test_rsi_overbought_warning_score(self):
        """Overbought readings (RSI >= 70) drop below neutral."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_rsi_score,
        )

        assert 0.3 <= calculate_rsi_score(70.0) <= 0.5
        assert 0.2 <= calculate_rsi_score(75.0) <= 0.5

    def test_rsi_extreme_overbought_score(self):
        """Extreme readings (RSI >= 85) score near the bottom of the range."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_rsi_score,
        )

        assert 0.0 <= calculate_rsi_score(85.0) <= 0.3
        assert 0.0 <= calculate_rsi_score(95.0) <= 0.2
class TestCalculateMacdScore:
    """Tests mapping MACD line/signal/histogram values onto a 0..1 score."""

    def test_macd_bullish_crossover_high_score(self):
        """MACD above its signal line with a positive histogram scores bullish."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_macd_score,
        )

        score = calculate_macd_score(macd=1.5, signal=1.0, histogram=0.5)
        assert 0.7 <= score <= 1.0

    def test_macd_bearish_crossover_low_score(self):
        """MACD below its signal line with a negative histogram scores bearish."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_macd_score,
        )

        score = calculate_macd_score(macd=-1.5, signal=-1.0, histogram=-0.5)
        assert 0.0 <= score <= 0.4

    def test_macd_neutral_score(self):
        """Near-zero MACD values land in the neutral middle of the range."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_macd_score,
        )

        score = calculate_macd_score(macd=0.1, signal=0.1, histogram=0.0)
        assert 0.4 <= score <= 0.6

    def test_macd_expanding_histogram_bullish(self):
        """A wide positive histogram keeps the score in bullish territory."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_macd_score,
        )

        score = calculate_macd_score(macd=2.0, signal=1.5, histogram=0.8)
        assert 0.6 <= score <= 1.0
class TestCalculateSmaScore:
    """Tests scoring of price position relative to the 50- and 200-day SMAs."""

    def test_price_above_both_smas_high_score(self):
        """Price above both moving averages scores bullish."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_sma_score,
        )

        score = calculate_sma_score(price=150.0, sma50=140.0, sma200=130.0)
        assert 0.7 <= score <= 1.0

    def test_price_below_both_smas_low_score(self):
        """Price below both moving averages scores bearish."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_sma_score,
        )

        score = calculate_sma_score(price=120.0, sma50=140.0, sma200=150.0)
        assert 0.0 <= score <= 0.4

    def test_price_between_smas_moderate_score(self):
        """Price sandwiched between the two SMAs scores in the middle band."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_sma_score,
        )

        score = calculate_sma_score(price=145.0, sma50=150.0, sma200=130.0)
        assert 0.4 <= score <= 0.7

    def test_golden_alignment_bonus(self):
        """Price > SMA50 > SMA200 ("golden" alignment) earns a high score."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_sma_score,
        )

        score = calculate_sma_score(price=160.0, sma50=150.0, sma200=140.0)
        assert score >= 0.8
class TestCalculateEmaDirection:
    """EMA slope classification into 'up', 'down', or 'flat'."""

    def test_ema_upward_trend(self):
        """A monotonically rising EMA series is classified as 'up'."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_ema_direction,
        )

        rising = [100.0, 102.0, 104.0, 106.0, 108.0]
        assert calculate_ema_direction(rising) == "up"

    def test_ema_downward_trend(self):
        """A monotonically falling EMA series is classified as 'down'."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_ema_direction,
        )

        falling = [108.0, 106.0, 104.0, 102.0, 100.0]
        assert calculate_ema_direction(falling) == "down"

    def test_ema_flat_trend(self):
        """Small oscillations around a level are classified as 'flat'."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_ema_direction,
        )

        sideways = [100.0, 100.1, 99.9, 100.0, 100.1]
        assert calculate_ema_direction(sideways) == "flat"

    def test_ema_empty_list_returns_flat(self):
        """No data degrades gracefully to 'flat' rather than raising."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_ema_direction,
        )

        assert calculate_ema_direction([]) == "flat"
class TestCalculateMomentumScore:
    """Tests for the aggregate momentum calculation built on bulk stock stats."""

    @patch("tradingagents.agents.discovery.indicators.momentum._get_stock_stats_bulk")
    def test_calculate_momentum_score_returns_dict(self, mock_get_stats):
        """With indicator history available, the result carries every expected
        key and a composite momentum_score normalised to [0, 1]."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_momentum_score,
        )

        # Fake per-date indicator series: RSI gets one value series; every
        # other indicator shares the second series.  Values are strings to
        # mirror the raw bulk-stats API payload.
        mock_get_stats.side_effect = lambda symbol, indicator, date: {
            "2024-01-15": "50.0" if indicator == "rsi" else "1.0",
            "2024-01-14": "49.0" if indicator == "rsi" else "0.9",
            "2024-01-13": "48.0" if indicator == "rsi" else "0.8",
            "2024-01-12": "47.0" if indicator == "rsi" else "0.7",
            "2024-01-11": "46.0" if indicator == "rsi" else "0.6",
        }
        result = calculate_momentum_score("AAPL", "2024-01-15")
        assert isinstance(result, dict)
        assert "rsi" in result
        assert "macd" in result
        assert "macd_signal" in result
        assert "macd_histogram" in result
        assert "price_vs_sma50" in result
        assert "price_vs_sma200" in result
        assert "ema10_direction" in result
        assert "momentum_score" in result
        assert 0.0 <= result["momentum_score"] <= 1.0

    @patch("tradingagents.agents.discovery.indicators.momentum._get_stock_stats_bulk")
    def test_calculate_momentum_score_handles_missing_data(self, mock_get_stats):
        """With no data at all, the score degrades to the neutral 0.5."""
        from tradingagents.agents.discovery.indicators.momentum import (
            calculate_momentum_score,
        )

        mock_get_stats.return_value = {}
        result = calculate_momentum_score("INVALID", "2024-01-15")
        assert isinstance(result, dict)
        assert result["momentum_score"] == 0.5

View File

@ -0,0 +1,344 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
class TestDiscoverTrendingIntegration:
    """Integration tests wiring discover_trending to the quantitative stage.

    Each test patches the news/extraction/scoring pipeline plus the
    quantitative enhancer at the trading_graph module level, then builds a
    TradingAgentsGraph with LLM/memory/graph-setup construction mocked out.
    """

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    @patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores")
    def test_discover_trending_calls_quantitative_enhancement(
        self, mock_enhance, mock_scores, mock_extract, mock_bulk_news
    ):
        """With enable_quantitative_filtering=True, the enhancer runs once."""
        from tradingagents.agents.discovery.models import (
            EventCategory,
            Sector,
            TrendingStock,
        )
        from tradingagents.dataflows.models import NewsArticle
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        mock_bulk_news.return_value = [
            NewsArticle(
                title="Test Article",
                source="Test",
                url="http://test.com",
                published_at=datetime.now(),
                content_snippet="Test content",
                ticker_mentions=["AAPL"],
            )
        ]
        mock_extract.return_value = []
        mock_stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc",
            score=85.0,
            mention_count=10,
            sentiment=0.7,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Test",
            source_articles=[],
        )
        mock_scores.return_value = [mock_stock]
        mock_enhance.return_value = [mock_stock]
        config = {
            "llm_provider": "openai",
            "quick_think_llm": "gpt-4o-mini",
            "deep_think_llm": "gpt-4o",
            "backend_url": "https://api.openai.com/v1",
            "project_dir": "/tmp/test",
            "database_enabled": False,
            "enable_quantitative_filtering": True,
        }
        with (
            patch("tradingagents.graph.trading_graph.ChatOpenAI"),
            patch("tradingagents.graph.trading_graph.FinancialSituationMemory"),
            patch("tradingagents.graph.trading_graph.GraphSetup"),
        ):
            graph = TradingAgentsGraph(config=config)
            # Called for its side effects on the mocked enhancer.
            result = graph.discover_trending()
            mock_enhance.assert_called_once()

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    @patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores")
    def test_discover_trending_skips_quantitative_when_disabled(
        self, mock_enhance, mock_scores, mock_extract, mock_bulk_news
    ):
        """With enable_quantitative_filtering=False, the enhancer never runs."""
        from tradingagents.agents.discovery.models import (
            EventCategory,
            Sector,
            TrendingStock,
        )
        from tradingagents.dataflows.models import NewsArticle
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        mock_bulk_news.return_value = [
            NewsArticle(
                title="Test Article",
                source="Test",
                url="http://test.com",
                published_at=datetime.now(),
                content_snippet="Test content",
                ticker_mentions=["AAPL"],
            )
        ]
        mock_extract.return_value = []
        mock_stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc",
            score=85.0,
            mention_count=10,
            sentiment=0.7,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Test",
            source_articles=[],
        )
        mock_scores.return_value = [mock_stock]
        config = {
            "llm_provider": "openai",
            "quick_think_llm": "gpt-4o-mini",
            "deep_think_llm": "gpt-4o",
            "backend_url": "https://api.openai.com/v1",
            "project_dir": "/tmp/test",
            "database_enabled": False,
            "enable_quantitative_filtering": False,
        }
        with (
            patch("tradingagents.graph.trading_graph.ChatOpenAI"),
            patch("tradingagents.graph.trading_graph.FinancialSituationMemory"),
            patch("tradingagents.graph.trading_graph.GraphSetup"),
        ):
            graph = TradingAgentsGraph(config=config)
            result = graph.discover_trending()
            mock_enhance.assert_not_called()

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    @patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores")
    def test_discover_trending_uses_config_max_stocks(
        self, mock_enhance, mock_scores, mock_extract, mock_bulk_news
    ):
        """quantitative_max_stocks from config is forwarded to the enhancer."""
        from tradingagents.agents.discovery.models import (
            EventCategory,
            Sector,
            TrendingStock,
        )
        from tradingagents.dataflows.models import NewsArticle
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        mock_bulk_news.return_value = [
            NewsArticle(
                title="Test Article",
                source="Test",
                url="http://test.com",
                published_at=datetime.now(),
                content_snippet="Test content",
                ticker_mentions=["AAPL"],
            )
        ]
        mock_extract.return_value = []
        # 30 stocks with descending scores, more than the configured cap of 25.
        mock_stocks = [
            TrendingStock(
                ticker=f"TICK{i}",
                company_name=f"Company {i}",
                score=100.0 - i,
                mention_count=10,
                sentiment=0.5,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.OTHER,
                news_summary="Test",
                source_articles=[],
            )
            for i in range(30)
        ]
        mock_scores.return_value = mock_stocks
        mock_enhance.return_value = mock_stocks
        config = {
            "llm_provider": "openai",
            "quick_think_llm": "gpt-4o-mini",
            "deep_think_llm": "gpt-4o",
            "backend_url": "https://api.openai.com/v1",
            "project_dir": "/tmp/test",
            "database_enabled": False,
            "enable_quantitative_filtering": True,
            "quantitative_max_stocks": 25,
        }
        with (
            patch("tradingagents.graph.trading_graph.ChatOpenAI"),
            patch("tradingagents.graph.trading_graph.FinancialSituationMemory"),
            patch("tradingagents.graph.trading_graph.GraphSetup"),
        ):
            graph = TradingAgentsGraph(config=config)
            result = graph.discover_trending()
            call_args = mock_enhance.call_args
            # The enhancer must receive the configured cap, not the default 50.
            assert call_args[1].get("max_stocks", 50) == 25
class TestScorerConvictionSupport:
    """Scorer output and TrendingStock model support for conviction scoring."""

    def test_calculate_trending_scores_preserves_original_score(self):
        """News-only scoring produces a positive score and leaves
        conviction_score unset (None) — conviction is added later by the
        quantitative stage."""
        from tradingagents.agents.discovery.entity_extractor import EntityMention
        from tradingagents.agents.discovery.models import (
            EventCategory,
            NewsArticle,
        )
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        mentions = [
            EntityMention(
                company_name="Apple",
                confidence=0.9,
                sentiment=0.7,
                event_type=EventCategory.EARNINGS,
                context_snippet="Apple reports strong earnings",
                article_id="article_0",
            ),
            EntityMention(
                company_name="Apple",
                confidence=0.85,
                sentiment=0.6,
                event_type=EventCategory.EARNINGS,
                context_snippet="Apple stock rises",
                article_id="article_1",
            ),
        ]
        articles = [
            NewsArticle(
                title="Article 1",
                source="Test",
                url="http://test.com/1",
                published_at=datetime.now(),
                content_snippet="Test content 1",
                ticker_mentions=["AAPL"],
            ),
            NewsArticle(
                title="Article 2",
                source="Test",
                url="http://test.com/2",
                published_at=datetime.now(),
                content_snippet="Test content 2",
                ticker_mentions=["AAPL"],
            ),
        ]
        with patch(
            "tradingagents.agents.discovery.scorer.resolve_ticker"
        ) as mock_resolve:
            mock_resolve.return_value = "AAPL"
            result = calculate_trending_scores(mentions, articles)
        # Both mentions resolve to the same ticker, so exactly one stock.
        assert len(result) == 1
        assert result[0].ticker == "AAPL"
        assert result[0].score > 0
        assert result[0].conviction_score is None

    def test_trending_stock_supports_conviction_score(self):
        """conviction_score and quantitative_metrics are assignable after
        construction and round-trip their values."""
        from tradingagents.agents.discovery.models import (
            EventCategory,
            Sector,
            TrendingStock,
        )
        from tradingagents.agents.discovery.quantitative_models import (
            QuantitativeMetrics,
        )

        stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc",
            score=85.0,
            mention_count=10,
            sentiment=0.7,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Test",
            source_articles=[],
        )
        assert stock.conviction_score is None
        stock.conviction_score = 0.85
        assert stock.conviction_score == 0.85
        metrics = QuantitativeMetrics(
            momentum_score=0.7,
            volume_score=0.6,
            relative_strength_score=0.65,
            risk_reward_score=0.7,
            quantitative_score=0.66,
        )
        stock.quantitative_metrics = metrics
        assert stock.quantitative_metrics.quantitative_score == 0.66
class TestBackwardCompatibility:
    """TrendingStock keeps working when the new quantitative fields are absent."""

    def test_trending_stock_without_quantitative_fields(self):
        """A stock built without quant data serialises with those fields
        either omitted or None."""
        from tradingagents.agents.discovery.models import (
            EventCategory,
            Sector,
            TrendingStock,
        )

        stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc",
            score=85.0,
            mention_count=10,
            sentiment=0.7,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Test",
            source_articles=[],
        )
        assert stock.quantitative_metrics is None
        assert stock.conviction_score is None
        stock_dict = stock.to_dict()
        # Accept either representation: key missing or key explicitly None.
        assert (
            "quantitative_metrics" not in stock_dict
            or stock_dict.get("quantitative_metrics") is None
        )
        assert (
            "conviction_score" not in stock_dict
            or stock_dict.get("conviction_score") is None
        )

    def test_trending_stock_from_dict_without_quantitative_fields(self):
        """from_dict on a legacy payload defaults the new fields to None."""
        from tradingagents.agents.discovery.models import TrendingStock

        data = {
            "ticker": "AAPL",
            "company_name": "Apple Inc",
            "score": 85.0,
            "mention_count": 10,
            "sentiment": 0.7,
            "sector": "technology",
            "event_type": "earnings",
            "news_summary": "Test",
            "source_articles": [],
        }
        stock = TrendingStock.from_dict(data)
        assert stock.ticker == "AAPL"
        assert stock.quantitative_metrics is None
        assert stock.conviction_score is None

View File

@ -0,0 +1,101 @@
import time
from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
class TestQuantitativeCache:
    """Behaviour of the per-run price-data cache: set/get, miss, clear, size cap."""

    def test_cache_set_and_get(self):
        """A stored DataFrame is returned intact for the same ticker."""
        from tradingagents.agents.discovery.quantitative_cache import (
            clear_run_cache,
            get_cached_price_data,
            set_cached_price_data,
        )

        clear_run_cache()
        df = pd.DataFrame(
            {"Close": [100.0, 101.0, 102.0], "Volume": [1000, 1100, 1200]}
        )
        set_cached_price_data("AAPL", df)
        cached = get_cached_price_data("AAPL")
        assert cached is not None
        assert len(cached) == 3
        assert cached["Close"].tolist() == [100.0, 101.0, 102.0]

    def test_cache_miss_returns_none(self):
        """An unknown ticker misses and returns None rather than raising."""
        from tradingagents.agents.discovery.quantitative_cache import (
            clear_run_cache,
            get_cached_price_data,
        )

        clear_run_cache()
        result = get_cached_price_data("NONEXISTENT")
        assert result is None

    def test_cache_clear(self):
        """clear_run_cache drops previously stored entries."""
        from tradingagents.agents.discovery.quantitative_cache import (
            clear_run_cache,
            get_cached_price_data,
            set_cached_price_data,
        )

        clear_run_cache()
        df = pd.DataFrame({"Close": [100.0]})
        set_cached_price_data("AAPL", df)
        assert get_cached_price_data("AAPL") is not None
        clear_run_cache()
        assert get_cached_price_data("AAPL") is None

    def test_cache_max_size_enforcement(self):
        """Storing more than MAX_CACHE_SIZE entries evicts, so at most
        MAX_CACHE_SIZE remain retrievable."""
        from tradingagents.agents.discovery.quantitative_cache import (
            MAX_CACHE_SIZE,
            clear_run_cache,
            get_cached_price_data,
            set_cached_price_data,
        )

        clear_run_cache()
        for i in range(MAX_CACHE_SIZE + 10):
            ticker = f"TICKER{i}"
            df = pd.DataFrame({"Close": [float(i)]})
            set_cached_price_data(ticker, df)
        cached_count = 0
        for i in range(MAX_CACHE_SIZE + 10):
            ticker = f"TICKER{i}"
            if get_cached_price_data(ticker) is not None:
                cached_count += 1
        assert cached_count <= MAX_CACHE_SIZE
class TestCacheTTLConstants:
    """Quant-related TTL entries registered in DEFAULT_TTL_HOURS."""

    def test_default_ttl_hours_contains_quant_entries(self):
        """Every quantitative cache category has a TTL entry."""
        from tradingagents.database.services.market_data import DEFAULT_TTL_HOURS

        for key in (
            "quant_indicators",
            "volume_analysis",
            "relative_strength",
            "support_resistance",
            "risk_reward",
        ):
            assert key in DEFAULT_TTL_HOURS

    def test_quant_ttl_values(self):
        """TTLs match the documented hours (relative strength refreshes less often)."""
        from tradingagents.database.services.market_data import DEFAULT_TTL_HOURS

        expected = {
            "quant_indicators": 1,
            "volume_analysis": 1,
            "relative_strength": 4,
            "support_resistance": 1,
            "risk_reward": 1,
        }
        for key, hours in expected.items():
            assert DEFAULT_TTL_HOURS[key] == hours

View File

@ -0,0 +1,124 @@
import os
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from tradingagents.config import (
QuantitativeWeightsConfig,
TradingAgentsSettings,
get_settings,
reset_settings,
)
class TestQuantitativeWeightsConfigDefaults:
    """Default weight values of QuantitativeWeightsConfig."""

    def test_default_weight_values_are_set_correctly(self):
        """Defaults split news/quant 50-50 and sub-weights per the spec."""
        config = QuantitativeWeightsConfig()
        expected = {
            "news_sentiment_weight": 0.50,
            "quantitative_weight": 0.50,
            "momentum_weight": 0.30,
            "volume_weight": 0.25,
            "relative_strength_weight": 0.25,
            "risk_reward_weight": 0.20,
        }
        for name, value in expected.items():
            assert getattr(config, name) == value
class TestQuantitativeWeightsConfigValidation:
    """Weight-sum validation rules on QuantitativeWeightsConfig."""

    def test_top_level_weights_sum_to_one(self):
        """A valid 0.60/0.40 split is accepted and sums to 1.0."""
        config = QuantitativeWeightsConfig(
            news_sentiment_weight=0.60,
            quantitative_weight=0.40,
        )
        assert (
            config.news_sentiment_weight + config.quantitative_weight
            == pytest.approx(1.0)
        )

    def test_sub_weights_sum_to_one(self):
        """The four default sub-weights sum to 1.0."""
        config = QuantitativeWeightsConfig()
        sub_weights_sum = (
            config.momentum_weight
            + config.volume_weight
            + config.relative_strength_weight
            + config.risk_reward_weight
        )
        assert sub_weights_sum == pytest.approx(1.0)

    def test_top_level_weights_validation_rejects_invalid_sum(self):
        """Top-level weights summing past 1.0 raise a ValidationError."""
        with pytest.raises(ValidationError) as exc_info:
            QuantitativeWeightsConfig(
                news_sentiment_weight=0.60,
                quantitative_weight=0.60,
            )
        assert "sum to 1.0" in str(exc_info.value).lower()

    def test_sub_weights_validation_rejects_invalid_sum(self):
        """Sub-weights summing past 1.0 raise a ValidationError."""
        with pytest.raises(ValidationError) as exc_info:
            QuantitativeWeightsConfig(
                momentum_weight=0.50,
                volume_weight=0.50,
                relative_strength_weight=0.50,
                risk_reward_weight=0.50,
            )
        assert "sum to 1.0" in str(exc_info.value).lower()
class TestQuantitativeWeightsConfigEnvOverride:
    """Environment-variable overrides for nested weight settings.

    Settings are cached, so each test resets the singleton before and after
    to avoid cross-test leakage.
    """

    def setup_method(self):
        reset_settings()

    def teardown_method(self):
        reset_settings()

    def test_environment_variable_override_functionality(self):
        """Nested TRADINGAGENTS_QUANTITATIVE_WEIGHTS__* env vars override
        the defaults once settings are re-created."""
        env_vars = {
            "TRADINGAGENTS_QUANTITATIVE_WEIGHTS__NEWS_SENTIMENT_WEIGHT": "0.70",
            "TRADINGAGENTS_QUANTITATIVE_WEIGHTS__QUANTITATIVE_WEIGHT": "0.30",
        }
        with patch.dict(os.environ, env_vars, clear=False):
            # Reset inside the patched environment so the new values apply.
            reset_settings()
            settings = get_settings()
            assert settings.quantitative_weights.news_sentiment_weight == pytest.approx(
                0.70
            )
            assert settings.quantitative_weights.quantitative_weight == pytest.approx(
                0.30
            )
class TestQuantitativeSettingsIntegration:
    """Quantitative settings exposed on TradingAgentsSettings."""

    def setup_method(self):
        reset_settings()

    def teardown_method(self):
        reset_settings()

    def test_quantitative_settings_in_trading_agents_settings(self):
        """All quantitative attributes exist on the top-level settings object."""
        settings = TradingAgentsSettings()
        assert hasattr(settings, "quantitative_weights")
        assert isinstance(settings.quantitative_weights, QuantitativeWeightsConfig)
        assert hasattr(settings, "quantitative_max_stocks")
        assert hasattr(settings, "quantitative_cache_ttl_intraday")
        assert hasattr(settings, "quantitative_cache_ttl_relative_strength")
        assert hasattr(settings, "min_dollar_volume")

    def test_quantitative_settings_default_values(self):
        """Defaults match the documented values."""
        settings = TradingAgentsSettings()
        assert settings.quantitative_max_stocks == 50
        assert settings.quantitative_cache_ttl_intraday == 1
        assert settings.quantitative_cache_ttl_relative_strength == 4
        assert settings.min_dollar_volume == 1_000_000.0

    def test_quantitative_max_stocks_bounds(self):
        """quantitative_max_stocks accepts in-range values and rejects
        values outside its validated bounds."""
        settings = TradingAgentsSettings(quantitative_max_stocks=75)
        assert settings.quantitative_max_stocks == 75
        with pytest.raises(ValidationError):
            TradingAgentsSettings(quantitative_max_stocks=5)
        with pytest.raises(ValidationError):
            TradingAgentsSettings(quantitative_max_stocks=150)

View File

@ -0,0 +1,465 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
from tradingagents.agents.discovery.models import (
EventCategory,
NewsArticle,
Sector,
TrendingStock,
)
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
from tradingagents.config import QuantitativeWeightsConfig
class TestFullPipelineIntegration:
    """End-to-end test of enhance_with_quantitative_scores with all four
    indicator calculators (momentum, volume, relative strength,
    support/resistance) patched to return fixed metric dicts."""

    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics"
    )
    def test_full_pipeline_news_to_quantitative_enhancement(
        self, mock_sr, mock_rs, mock_vol, mock_mom
    ):
        """Enhanced stocks carry a conviction score in [0, 1] and the raw
        indicator values supplied by the mocked calculators."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            enhance_with_quantitative_scores,
        )

        mock_mom.return_value = {
            "rsi": 45.0,
            "macd": 0.5,
            "macd_signal": 0.4,
            "macd_histogram": 0.1,
            "price_vs_sma50": 5.0,
            "price_vs_sma200": 10.0,
            "ema10_direction": "up",
            "momentum_score": 0.7,
        }
        mock_vol.return_value = {
            "volume_ratio": 1.8,
            "volume_trend": "increasing",
            "dollar_volume": 25000000.0,
            "volume_score": 0.75,
        }
        mock_rs.return_value = {
            "rs_vs_spy_5d": 3.0,
            "rs_vs_spy_20d": 5.0,
            "rs_vs_spy_60d": 8.0,
            "rs_vs_sector": 2.5,
            "sector_etf": "XLK",
            "relative_strength_score": 0.8,
        }
        mock_sr.return_value = {
            "support_level": 145.0,
            "resistance_level": 175.0,
            "atr": 3.0,
            "suggested_stop": 140.5,
            "reward_target": 175.0,
            "risk_reward_ratio": 2.5,
            "risk_reward_score": 0.85,
        }
        articles = [
            NewsArticle(
                title="Apple Reports Strong Quarter",
                source="Reuters",
                url="http://example.com/1",
                published_at=datetime.now(),
                content_snippet="Apple beats expectations",
                ticker_mentions=["AAPL"],
            )
        ]
        stocks = [
            TrendingStock(
                ticker="AAPL",
                company_name="Apple Inc",
                score=85.0,
                mention_count=15,
                sentiment=0.75,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.EARNINGS,
                news_summary="Apple reports strong quarterly results",
                source_articles=articles,
            ),
            TrendingStock(
                ticker="MSFT",
                company_name="Microsoft Corp",
                score=72.0,
                mention_count=10,
                sentiment=0.6,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.PRODUCT_LAUNCH,
                news_summary="Microsoft launches new product",
                source_articles=[],
            ),
        ]
        result = enhance_with_quantitative_scores(stocks, "2024-01-15")
        # Both stocks come back; those with metrics gain a bounded
        # conviction score and the mocked indicator values.
        assert len(result) == 2
        for stock in result:
            if stock.quantitative_metrics is not None:
                assert stock.conviction_score is not None
                assert 0.0 <= stock.conviction_score <= 1.0
                assert stock.quantitative_metrics.rsi == 45.0
                assert stock.quantitative_metrics.sector_etf == "XLK"
class TestDelistedStockHandling:
    """Scorer behaviour when a ticker has no usable price history (e.g. delisted)."""

    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
    )
    def test_stock_with_no_trading_data_returns_none(self, mock_mom):
        """An indicator exception yields a None result rather than propagating."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            calculate_single_stock_metrics,
        )
        mock_mom.side_effect = Exception("No price data available for delisted stock")
        result = calculate_single_stock_metrics("DELIST", "2024-01-15")
        assert result is None

    # Stacked @patch decorators bind bottom-up: mock_sr is the decorator
    # nearest the function (support/resistance), mock_mom the outermost.
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics"
    )
    def test_enhance_continues_after_delisted_stock_failure(
        self, mock_sr, mock_rs, mock_vol, mock_mom
    ):
        """One failing ticker must not abort enhancement of the remaining stocks."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            enhance_with_quantitative_scores,
        )
        def mom_side_effect(ticker, date):
            # Simulate missing data for exactly one ticker.
            if ticker == "DELIST":
                raise Exception("No data for delisted stock")
            return {"momentum_score": 0.6, "rsi": 50.0}
        mock_mom.side_effect = mom_side_effect
        mock_vol.return_value = {"volume_score": 0.5}
        mock_rs.return_value = {"relative_strength_score": 0.5, "sector_etf": "SPY"}
        mock_sr.return_value = {"risk_reward_score": 0.5}
        stocks = [
            TrendingStock(
                ticker="AAPL",
                company_name="Apple Inc",
                score=90.0,
                mention_count=10,
                sentiment=0.5,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.OTHER,
                news_summary="Test",
                source_articles=[],
            ),
            TrendingStock(
                ticker="DELIST",
                company_name="Delisted Corp",
                score=80.0,
                mention_count=8,
                sentiment=0.4,
                sector=Sector.OTHER,
                event_type=EventCategory.OTHER,
                news_summary="Test",
                source_articles=[],
            ),
        ]
        result = enhance_with_quantitative_scores(stocks, "2024-01-15")
        # Both stocks survive; only the healthy one carries metrics.
        assert len(result) == 2
        aapl = next((s for s in result if s.ticker == "AAPL"), None)
        delist = next((s for s in result if s.ticker == "DELIST"), None)
        assert aapl is not None
        assert aapl.quantitative_metrics is not None
        assert delist is not None
        assert delist.quantitative_metrics is None
class TestWeightConfigurationEdgeCases:
    """Boundary behaviour of QuantitativeWeightsConfig weight validation."""

    def test_weights_near_boundary_accepted(self):
        """Top-level weights summing to ~1.0 within tolerance are accepted."""
        config = QuantitativeWeightsConfig(
            news_sentiment_weight=0.501,
            quantitative_weight=0.499,
        )
        combined = config.news_sentiment_weight + config.quantitative_weight
        assert abs(combined - 1.0) < 0.01

    def test_sub_weights_custom_values_accepted(self):
        """Custom component sub-weights that sum to 1.0 are accepted."""
        config = QuantitativeWeightsConfig(
            momentum_weight=0.40,
            volume_weight=0.20,
            relative_strength_weight=0.20,
            risk_reward_weight=0.20,
        )
        parts = (
            config.momentum_weight,
            config.volume_weight,
            config.relative_strength_weight,
            config.risk_reward_weight,
        )
        assert abs(sum(parts) - 1.0) < 0.01
class TestCacheConcurrentAccess:
    """Thread-safety of the per-run quantitative price-data cache."""

    def test_cache_thread_safety_under_concurrent_writes(self):
        """Fifty threads write and read distinct tickers concurrently.

        Verifies that no exceptions escape the cache API under contention and
        that every thread reads back exactly the value it wrote.
        """
        from tradingagents.agents.discovery.quantitative_cache import (
            clear_run_cache,
            get_cached_price_data,
            set_cached_price_data,
        )

        clear_run_cache()
        errors = []
        results = {}

        def write_to_cache(ticker_num):
            try:
                ticker = f"TICK{ticker_num}"
                df = pd.DataFrame({"Close": [float(ticker_num)]})
                set_cached_price_data(ticker, df)
                cached = get_cached_price_data(ticker)
                if cached is not None:
                    results[ticker] = cached["Close"].iloc[0]
            except Exception as e:
                errors.append(str(e))

        try:
            with ThreadPoolExecutor(max_workers=20) as executor:
                futures = [executor.submit(write_to_cache, i) for i in range(50)]
                # Surface any exception raised inside a worker.
                for f in futures:
                    f.result()
            assert len(errors) == 0
            # Every thread must observe its own write, with the correct value —
            # previously the read-back values were collected but never checked.
            assert len(results) == 50
            for i in range(50):
                assert results[f"TICK{i}"] == float(i)
        finally:
            # Always reset shared cache state, even when an assertion fails,
            # so later tests are not polluted.
            clear_run_cache()
class TestConvictionScoreRanking:
    """Combined news + quantitative conviction must rank stocks sensibly."""

    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
    )
    def test_conviction_score_ranking_accuracy(self, mock_calc):
        """Strong news + strong quant ranks first and beats weak-both overall."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            enhance_with_quantitative_scores,
        )
        def side_effect(ticker, date):
            # Ticker -> (quantitative score, second value). Only the first
            # element feeds the mocked metrics; the second is unused here.
            scores = {
                "HIGH_BOTH": (0.9, 0.9),
                "HIGH_NEWS_LOW_QUANT": (0.3, 0.9),
                "LOW_NEWS_HIGH_QUANT": (0.9, 0.3),
                "LOW_BOTH": (0.3, 0.3),
            }
            quant, _ = scores.get(ticker, (0.5, 0.5))
            return QuantitativeMetrics(
                momentum_score=quant,
                volume_score=quant,
                relative_strength_score=quant,
                risk_reward_score=quant,
                quantitative_score=quant,
            )
        mock_calc.side_effect = side_effect
        # News-side strength is encoded via score/mention_count/sentiment.
        stocks = [
            TrendingStock(
                ticker="LOW_BOTH",
                company_name="Low Both",
                score=30.0,
                mention_count=5,
                sentiment=0.3,
                sector=Sector.OTHER,
                event_type=EventCategory.OTHER,
                news_summary="Test",
                source_articles=[],
            ),
            TrendingStock(
                ticker="HIGH_BOTH",
                company_name="High Both",
                score=90.0,
                mention_count=15,
                sentiment=0.9,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.EARNINGS,
                news_summary="Test",
                source_articles=[],
            ),
            TrendingStock(
                ticker="HIGH_NEWS_LOW_QUANT",
                company_name="High News Low Quant",
                score=90.0,
                mention_count=15,
                sentiment=0.9,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.OTHER,
                news_summary="Test",
                source_articles=[],
            ),
            TrendingStock(
                ticker="LOW_NEWS_HIGH_QUANT",
                company_name="Low News High Quant",
                score=30.0,
                mention_count=5,
                sentiment=0.3,
                sector=Sector.OTHER,
                event_type=EventCategory.OTHER,
                news_summary="Test",
                source_articles=[],
            ),
        ]
        result = enhance_with_quantitative_scores(stocks, "2024-01-15")
        # Result is expected to be sorted by conviction, best first.
        assert result[0].ticker == "HIGH_BOTH"
        high_both = next(s for s in result if s.ticker == "HIGH_BOTH")
        low_both = next(s for s in result if s.ticker == "LOW_BOTH")
        assert high_both.conviction_score > low_both.conviction_score
class TestTrendingStockSerializationWithQuantitativeMetrics:
    """TrendingStock round-trips through to_dict/from_dict with nested metrics."""

    def test_trending_stock_with_quantitative_metrics_roundtrip(self):
        """Serialisation preserves scalar fields and the nested QuantitativeMetrics."""
        articles = [
            NewsArticle(
                title="Test Article",
                source="Test Source",
                url="http://test.com",
                published_at=datetime(2024, 1, 15, 10, 30, 0),
                content_snippet="Test content",
                ticker_mentions=["AAPL"],
            )
        ]
        # Fully-populated metrics so every field exercises serialisation.
        metrics = QuantitativeMetrics(
            momentum_score=0.75,
            volume_score=0.65,
            relative_strength_score=0.80,
            risk_reward_score=0.70,
            rsi=42.5,
            macd=0.35,
            macd_signal=0.28,
            macd_histogram=0.07,
            price_vs_sma50=4.5,
            price_vs_sma200=9.2,
            ema10_direction="up",
            volume_ratio=1.65,
            volume_trend="increasing",
            dollar_volume=18500000.0,
            rs_vs_spy_5d=2.8,
            rs_vs_spy_20d=4.2,
            rs_vs_spy_60d=7.5,
            rs_vs_sector=2.1,
            sector_etf="XLK",
            support_level=148.50,
            resistance_level=165.75,
            atr=2.85,
            suggested_stop=144.22,
            reward_target=165.75,
            risk_reward_ratio=2.65,
            quantitative_score=0.725,
        )
        original = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc",
            score=87.5,
            mention_count=12,
            sentiment=0.72,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Apple reports strong quarterly earnings",
            source_articles=articles,
            quantitative_metrics=metrics,
            conviction_score=0.815,
        )
        data = original.to_dict()
        restored = TrendingStock.from_dict(data)
        assert restored.ticker == original.ticker
        assert restored.score == original.score
        assert restored.conviction_score == original.conviction_score
        # Spot-check nested metrics fields survive the round trip.
        assert restored.quantitative_metrics is not None
        assert restored.quantitative_metrics.momentum_score == 0.75
        assert restored.quantitative_metrics.rsi == 42.5
        assert restored.quantitative_metrics.sector_etf == "XLK"
        assert restored.quantitative_metrics.quantitative_score == 0.725
class TestModuleIntegration:
    """calculate_unified_score integrates component scores via configured weights."""

    def test_unified_score_combines_all_indicators_correctly(self):
        """The unified score equals the weight-scaled sum of the components."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            calculate_unified_score,
        )

        weights = QuantitativeWeightsConfig()
        components = {"momentum": 0.8, "volume": 0.6, "rs": 0.7, "rr": 0.9}
        score = calculate_unified_score(weights=weights, **components)
        expected = (
            components["momentum"] * weights.momentum_weight
            + components["volume"] * weights.volume_weight
            + components["rs"] * weights.relative_strength_weight
            + components["rr"] * weights.risk_reward_weight
        )
        assert abs(score - expected) < 0.001

    def test_unified_score_clamped_to_valid_range(self):
        """Out-of-range component values clamp the result to [0.0, 1.0]."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            calculate_unified_score,
        )

        weights = QuantitativeWeightsConfig()
        high = calculate_unified_score(
            momentum=1.5, volume=1.5, rs=1.5, rr=1.5, weights=weights
        )
        assert high == 1.0
        low = calculate_unified_score(
            momentum=-0.5, volume=-0.5, rs=-0.5, rr=-0.5, weights=weights
        )
        assert low == 0.0

View File

@ -0,0 +1,267 @@
import pytest
from pydantic import ValidationError
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
class TestQuantitativeMetricsInstantiation:
    """Constructing QuantitativeMetrics from a full, valid field set."""

    def test_model_instantiation_with_valid_data(self):
        """All constructor fields are stored unchanged on the model."""
        metrics = QuantitativeMetrics(
            momentum_score=0.75,
            volume_score=0.60,
            relative_strength_score=0.80,
            risk_reward_score=0.70,
            rsi=45.5,
            macd=0.25,
            macd_signal=0.20,
            macd_histogram=0.05,
            price_vs_sma50=5.2,
            price_vs_sma200=12.3,
            ema10_direction="up",
            volume_ratio=1.8,
            volume_trend="increasing",
            dollar_volume=15_000_000.0,
            rs_vs_spy_5d=2.5,
            rs_vs_spy_20d=5.0,
            rs_vs_spy_60d=8.2,
            rs_vs_sector=3.1,
            sector_etf="XLK",
            support_level=145.50,
            resistance_level=162.30,
            atr=3.25,
            suggested_stop=140.62,
            reward_target=162.30,
            risk_reward_ratio=2.8,
            quantitative_score=0.72,
        )
        # Spot-check a representative subset of stored fields.
        assert metrics.momentum_score == 0.75
        assert metrics.volume_score == 0.60
        assert metrics.relative_strength_score == 0.80
        assert metrics.risk_reward_score == 0.70
        assert metrics.rsi == 45.5
        assert metrics.macd == 0.25
        assert metrics.ema10_direction == "up"
        assert metrics.sector_etf == "XLK"
        assert metrics.quantitative_score == 0.72
class TestQuantitativeMetricsScoreValidation:
    """Component scores must lie in the closed interval [0.0, 1.0]."""

    @staticmethod
    def _metrics_kwargs(**overrides):
        """Return a baseline valid constructor kwarg set; overrides replace fields.

        Extracted to remove three near-identical 26-field fixtures that
        previously made it hard to see which field each test actually varies.
        """
        kwargs = dict(
            momentum_score=0.5,
            volume_score=0.5,
            relative_strength_score=0.5,
            risk_reward_score=0.5,
            rsi=50.0,
            macd=0.0,
            macd_signal=0.0,
            macd_histogram=0.0,
            price_vs_sma50=0.0,
            price_vs_sma200=0.0,
            ema10_direction="flat",
            volume_ratio=1.0,
            volume_trend="flat",
            dollar_volume=1_000_000.0,
            rs_vs_spy_5d=0.0,
            rs_vs_spy_20d=0.0,
            rs_vs_spy_60d=0.0,
            rs_vs_sector=0.0,
            sector_etf="SPY",
            support_level=100.0,
            resistance_level=110.0,
            atr=2.0,
            suggested_stop=97.0,
            reward_target=110.0,
            risk_reward_ratio=3.33,
            quantitative_score=0.5,
        )
        kwargs.update(overrides)
        return kwargs

    def test_score_validation_accepts_valid_range(self):
        """Scores at 0.0, 0.5, 0.99 and 1.0 are all accepted."""
        metrics = QuantitativeMetrics(
            **self._metrics_kwargs(
                momentum_score=0.0,
                relative_strength_score=1.0,
                risk_reward_score=0.99,
            )
        )
        assert metrics.momentum_score == 0.0
        assert metrics.relative_strength_score == 1.0
        assert metrics.quantitative_score == 0.5

    def test_score_validation_rejects_negative_score(self):
        """A negative component score fails validation, naming the field."""
        with pytest.raises(ValidationError) as exc_info:
            QuantitativeMetrics(**self._metrics_kwargs(momentum_score=-0.1))
        assert "momentum_score" in str(exc_info.value)

    def test_score_validation_rejects_above_one(self):
        """A component score above 1.0 fails validation, naming the field."""
        with pytest.raises(ValidationError) as exc_info:
            QuantitativeMetrics(**self._metrics_kwargs(volume_score=1.5))
        assert "volume_score" in str(exc_info.value)
class TestQuantitativeMetricsSerialization:
    """to_dict/from_dict round-trip fidelity for QuantitativeMetrics."""

    def test_to_dict_and_from_dict_roundtrip(self):
        """Every field survives serialisation and deserialisation unchanged."""
        original = QuantitativeMetrics(
            momentum_score=0.75,
            volume_score=0.60,
            relative_strength_score=0.80,
            risk_reward_score=0.70,
            rsi=45.5,
            macd=0.25,
            macd_signal=0.20,
            macd_histogram=0.05,
            price_vs_sma50=5.2,
            price_vs_sma200=12.3,
            ema10_direction="up",
            volume_ratio=1.8,
            volume_trend="increasing",
            dollar_volume=15_000_000.0,
            rs_vs_spy_5d=2.5,
            rs_vs_spy_20d=5.0,
            rs_vs_spy_60d=8.2,
            rs_vs_sector=3.1,
            sector_etf="XLK",
            support_level=145.50,
            resistance_level=162.30,
            atr=3.25,
            suggested_stop=140.62,
            reward_target=162.30,
            risk_reward_ratio=2.8,
            quantitative_score=0.72,
        )
        data = original.to_dict()
        restored = QuantitativeMetrics.from_dict(data)
        # Compare a representative subset covering floats and strings.
        assert restored.momentum_score == original.momentum_score
        assert restored.volume_score == original.volume_score
        assert restored.relative_strength_score == original.relative_strength_score
        assert restored.risk_reward_score == original.risk_reward_score
        assert restored.rsi == original.rsi
        assert restored.macd == original.macd
        assert restored.ema10_direction == original.ema10_direction
        assert restored.sector_etf == original.sector_etf
        assert restored.quantitative_score == original.quantitative_score
class TestQuantitativeMetricsOptionalFields:
    """Optional indicator fields accept None while score fields stay required."""

    @staticmethod
    def _metrics_with_all_optionals_none():
        """Build valid metrics with every optional indicator field set to None.

        Extracted because the two tests below previously duplicated this
        26-field fixture byte-for-byte.
        """
        return QuantitativeMetrics(
            momentum_score=0.5,
            volume_score=0.5,
            relative_strength_score=0.5,
            risk_reward_score=0.5,
            rsi=None,
            macd=None,
            macd_signal=None,
            macd_histogram=None,
            price_vs_sma50=None,
            price_vs_sma200=None,
            ema10_direction=None,
            volume_ratio=None,
            volume_trend=None,
            dollar_volume=None,
            rs_vs_spy_5d=None,
            rs_vs_spy_20d=None,
            rs_vs_spy_60d=None,
            rs_vs_sector=None,
            sector_etf=None,
            support_level=None,
            resistance_level=None,
            atr=None,
            suggested_stop=None,
            reward_target=None,
            risk_reward_ratio=None,
            quantitative_score=0.5,
        )

    def test_optional_field_handling_with_none_defaults(self):
        """None is accepted for optional fields; required scores keep values."""
        metrics = self._metrics_with_all_optionals_none()
        assert metrics.rsi is None
        assert metrics.macd is None
        assert metrics.ema10_direction is None
        assert metrics.sector_etf is None
        assert metrics.momentum_score == 0.5
        assert metrics.quantitative_score == 0.5

    def test_serialization_with_none_values(self):
        """None values round-trip through to_dict/from_dict unchanged."""
        original = self._metrics_with_all_optionals_none()
        data = original.to_dict()
        restored = QuantitativeMetrics.from_dict(data)
        assert restored.rsi is None
        assert restored.ema10_direction is None
        assert restored.momentum_score == 0.5

View File

@ -0,0 +1,312 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
class TestCalculateSingleStockMetrics:
    """calculate_single_stock_metrics happy path and failure handling."""

    # Stacked @patch decorators bind bottom-up: mock_sr is the innermost
    # decorator (support/resistance), mock_mom the outermost (momentum).
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics"
    )
    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics"
    )
    def test_calculate_single_stock_metrics_success(
        self, mock_sr, mock_rs, mock_vol, mock_mom
    ):
        """Per-indicator results are merged into one QuantitativeMetrics object."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            calculate_single_stock_metrics,
        )
        mock_mom.return_value = {
            "rsi": 45.0,
            "macd": 0.5,
            "macd_signal": 0.4,
            "macd_histogram": 0.1,
            "price_vs_sma50": 2.5,
            "price_vs_sma200": 5.0,
            "ema10_direction": "up",
            "momentum_score": 0.65,
        }
        mock_vol.return_value = {
            "volume_ratio": 1.5,
            "volume_trend": "increasing",
            "dollar_volume": 5000000.0,
            "volume_score": 0.7,
        }
        mock_rs.return_value = {
            "rs_vs_spy_5d": 2.0,
            "rs_vs_spy_20d": 3.5,
            "rs_vs_spy_60d": 5.0,
            "rs_vs_sector": 2.5,
            "sector_etf": "XLK",
            "relative_strength_score": 0.6,
        }
        mock_sr.return_value = {
            "support_level": 150.0,
            "resistance_level": 180.0,
            "atr": 3.5,
            "suggested_stop": 145.0,
            "reward_target": 180.0,
            "risk_reward_ratio": 2.5,
            "risk_reward_score": 0.75,
        }
        result = calculate_single_stock_metrics("AAPL", "2024-01-15")
        assert result is not None
        # Each component score is carried through; the composite is bounded.
        assert result.momentum_score == 0.65
        assert result.volume_score == 0.7
        assert result.relative_strength_score == 0.6
        assert result.risk_reward_score == 0.75
        assert 0.0 <= result.quantitative_score <= 1.0

    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
    )
    def test_calculate_single_stock_metrics_failure(self, mock_mom):
        """An indicator exception downgrades to a None result, not a crash."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            calculate_single_stock_metrics,
        )
        mock_mom.side_effect = Exception("API failure")
        result = calculate_single_stock_metrics("INVALID", "2024-01-15")
        assert result is None
class TestCalculateUnifiedScore:
    """Weighted-sum behaviour of calculate_unified_score with default weights."""

    @staticmethod
    def _score(momentum, volume, rs, rr):
        """Call calculate_unified_score with default weights for the given parts."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            calculate_unified_score,
        )
        from tradingagents.config import QuantitativeWeightsConfig

        return calculate_unified_score(
            momentum=momentum,
            volume=volume,
            rs=rs,
            rr=rr,
            weights=QuantitativeWeightsConfig(),
        )

    def test_unified_score_basic(self):
        """Mixed component values combine via the default 30/25/25/20 weights."""
        expected = 0.8 * 0.30 + 0.6 * 0.25 + 0.7 * 0.25 + 0.5 * 0.20
        assert abs(self._score(0.8, 0.6, 0.7, 0.5) - expected) < 0.001

    def test_unified_score_all_ones(self):
        """All components at 1.0 produce a unified score of 1.0."""
        assert abs(self._score(1.0, 1.0, 1.0, 1.0) - 1.0) < 0.001

    def test_unified_score_all_zeros(self):
        """All components at 0.0 produce a unified score of 0.0."""
        assert abs(self._score(0.0, 0.0, 0.0, 0.0) - 0.0) < 0.001
class TestEnhanceWithQuantitativeScores:
    """Behaviour of enhance_with_quantitative_scores over ranked candidates."""

    def _create_mock_trending_stock(self, ticker: str, score: float):
        """Build a minimal TrendingStock fixture with the given ticker and score."""
        # NewsArticle was previously imported here but never used; removed.
        from tradingagents.agents.discovery.models import (
            EventCategory,
            Sector,
            TrendingStock,
        )

        return TrendingStock(
            ticker=ticker,
            company_name=f"{ticker} Inc",
            score=score,
            mention_count=10,
            sentiment=0.5,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Test summary",
            source_articles=[],
        )

    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
    )
    def test_enhance_with_quantitative_scores_basic(self, mock_calc):
        """Every stock gets metrics and a conviction score when analysis succeeds."""
        from tradingagents.agents.discovery.quantitative_models import (
            QuantitativeMetrics,
        )
        from tradingagents.agents.discovery.quantitative_scorer import (
            enhance_with_quantitative_scores,
        )

        mock_calc.return_value = QuantitativeMetrics(
            momentum_score=0.7,
            volume_score=0.6,
            relative_strength_score=0.65,
            risk_reward_score=0.7,
            quantitative_score=0.66,
        )
        stocks = [
            self._create_mock_trending_stock("AAPL", 90.0),
            self._create_mock_trending_stock("MSFT", 80.0),
        ]
        result = enhance_with_quantitative_scores(stocks, "2024-01-15")
        assert len(result) == 2
        for stock in result:
            assert stock.quantitative_metrics is not None
            assert stock.conviction_score is not None

    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
    )
    def test_enhance_caps_at_max_stocks(self, mock_calc):
        """Only the top max_stocks candidates are actually analysed."""
        from tradingagents.agents.discovery.quantitative_models import (
            QuantitativeMetrics,
        )
        from tradingagents.agents.discovery.quantitative_scorer import (
            enhance_with_quantitative_scores,
        )

        mock_calc.return_value = QuantitativeMetrics(
            momentum_score=0.5,
            volume_score=0.5,
            relative_strength_score=0.5,
            risk_reward_score=0.5,
            quantitative_score=0.5,
        )
        stocks = [
            self._create_mock_trending_stock(f"TICK{i}", 100.0 - i) for i in range(60)
        ]
        # The cap is observed via the mock's call count; the return value is
        # irrelevant here, so it is deliberately not bound to a name.
        enhance_with_quantitative_scores(stocks, "2024-01-15", max_stocks=10)
        assert mock_calc.call_count == 10

    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
    )
    def test_enhance_handles_partial_failures(self, mock_calc):
        """A None result for one ticker leaves it metric-less but retained."""
        from tradingagents.agents.discovery.quantitative_models import (
            QuantitativeMetrics,
        )
        from tradingagents.agents.discovery.quantitative_scorer import (
            enhance_with_quantitative_scores,
        )

        def side_effect(ticker, date):
            if ticker == "FAIL":
                return None
            return QuantitativeMetrics(
                momentum_score=0.5,
                volume_score=0.5,
                relative_strength_score=0.5,
                risk_reward_score=0.5,
                quantitative_score=0.5,
            )

        mock_calc.side_effect = side_effect
        stocks = [
            self._create_mock_trending_stock("AAPL", 90.0),
            self._create_mock_trending_stock("FAIL", 85.0),
            self._create_mock_trending_stock("MSFT", 80.0),
        ]
        result = enhance_with_quantitative_scores(stocks, "2024-01-15")
        assert len(result) == 3
        successful = [s for s in result if s.quantitative_metrics is not None]
        assert len(successful) == 2

    @patch(
        "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
    )
    def test_enhance_sorts_by_conviction_score(self, mock_calc):
        """The returned list is ordered best-first by combined conviction."""
        from tradingagents.agents.discovery.quantitative_models import (
            QuantitativeMetrics,
        )
        from tradingagents.agents.discovery.quantitative_scorer import (
            enhance_with_quantitative_scores,
        )

        def side_effect(ticker, date):
            scores = {
                "LOW": 0.3,
                "MED": 0.5,
                "HIGH": 0.9,
            }
            quant_score = scores.get(ticker, 0.5)
            return QuantitativeMetrics(
                momentum_score=quant_score,
                volume_score=quant_score,
                relative_strength_score=quant_score,
                risk_reward_score=quant_score,
                quantitative_score=quant_score,
            )

        mock_calc.side_effect = side_effect
        stocks = [
            self._create_mock_trending_stock("LOW", 60.0),
            self._create_mock_trending_stock("HIGH", 40.0),
            self._create_mock_trending_stock("MED", 50.0),
        ]
        result = enhance_with_quantitative_scores(stocks, "2024-01-15")
        # News scores are deliberately inverted vs quant scores; ordering must
        # still be monotonically non-increasing in quantitative score.
        assert (
            result[0].quantitative_metrics.quantitative_score
            >= result[1].quantitative_metrics.quantitative_score
        )
        assert (
            result[1].quantitative_metrics.quantitative_score
            >= result[2].quantitative_metrics.quantitative_score
        )
class TestErrorHandling:
    """Degenerate inputs to enhance_with_quantitative_scores."""

    def test_empty_stock_list(self):
        """An empty candidate list yields an empty result without raising."""
        from tradingagents.agents.discovery.quantitative_scorer import (
            enhance_with_quantitative_scores,
        )

        enhanced = enhance_with_quantitative_scores([], "2024-01-15")
        assert enhanced == []

View File

@ -0,0 +1,190 @@
from unittest.mock import MagicMock, patch
import pytest
class TestCalculateReturn:
    """Percentage-return computation over a trailing window of closes."""

    @staticmethod
    def _ret(prices, days):
        """Delegate to calculate_return with the given window size."""
        from tradingagents.agents.discovery.indicators.relative_strength import (
            calculate_return,
        )

        return calculate_return(prices, days=days)

    def test_calculate_positive_return(self):
        """A rising series yields a positive percentage return."""
        series = [100.0, 102.0, 105.0, 108.0, 110.0]
        assert self._ret(series, 5) == pytest.approx(10.0, rel=0.01)

    def test_calculate_negative_return(self):
        """A falling series yields a negative percentage return."""
        series = [100.0, 98.0, 95.0, 92.0, 90.0]
        assert self._ret(series, 5) == pytest.approx(-10.0, rel=0.01)

    def test_calculate_return_insufficient_data(self):
        """With fewer prices than `days`, the whole available window is used."""
        assert self._ret([100.0, 110.0], 5) == pytest.approx(10.0, rel=0.01)

    def test_calculate_return_empty_list(self):
        """An empty price list degrades to a 0.0 return."""
        assert self._ret([], 5) == 0.0
class TestCalculateRelativeStrength:
    """Relative strength is the stock return minus the benchmark return."""

    @staticmethod
    def _rs(stock_return, benchmark_return):
        """Delegate to calculate_relative_strength with named arguments."""
        from tradingagents.agents.discovery.indicators.relative_strength import (
            calculate_relative_strength,
        )

        return calculate_relative_strength(
            stock_return=stock_return, benchmark_return=benchmark_return
        )

    def test_positive_outperformance(self):
        """Beating the benchmark produces a positive RS value."""
        assert self._rs(15.0, 10.0) == 5.0

    def test_negative_outperformance(self):
        """Trailing the benchmark produces a negative RS value."""
        assert self._rs(5.0, 10.0) == -5.0

    def test_equal_returns(self):
        """Matching the benchmark produces exactly zero."""
        assert self._rs(10.0, 10.0) == 0.0
class TestSectorEtfMap:
    """Contents of the static sector -> ETF lookup table."""

    def test_sector_etf_map_contains_expected_sectors(self):
        """Every supported sector key is present in the map."""
        from tradingagents.agents.discovery.indicators.relative_strength import (
            SECTOR_ETF_MAP,
        )

        required_sectors = (
            "technology",
            "finance",
            "healthcare",
            "energy",
            "consumer_goods",
            "industrials",
            "other",
        )
        for sector in required_sectors:
            assert sector in SECTOR_ETF_MAP

    def test_sector_etf_map_values(self):
        """Sector keys resolve to their expected ETF symbols."""
        from tradingagents.agents.discovery.indicators.relative_strength import (
            SECTOR_ETF_MAP,
        )

        expected = {
            "technology": "XLK",
            "finance": "XLF",
            "healthcare": "XLV",
            "energy": "XLE",
            "other": "SPY",
        }
        for sector, etf in expected.items():
            assert SECTOR_ETF_MAP[sector] == etf
class TestGetSectorEtf:
    """get_sector_etf maps a classified sector to its benchmark ETF symbol."""

    @staticmethod
    def _etf_for(ticker, classified_sector):
        """Resolve the ETF for a ticker with sector classification stubbed out."""
        from tradingagents.agents.discovery.indicators.relative_strength import (
            get_sector_etf,
        )

        with patch(
            "tradingagents.agents.discovery.indicators.relative_strength.classify_sector"
        ) as mock_classify:
            mock_classify.return_value = classified_sector
            return get_sector_etf(ticker)

    def test_get_sector_etf_for_tech_stock(self):
        """Technology names map to XLK."""
        assert self._etf_for("AAPL", "technology") == "XLK"

    def test_get_sector_etf_for_finance_stock(self):
        """Finance names map to XLF."""
        assert self._etf_for("JPM", "finance") == "XLF"

    def test_get_sector_etf_for_unknown_sector(self):
        """Unclassified names fall back to the broad-market SPY benchmark."""
        assert self._etf_for("XYZ", "other") == "SPY"
class TestCalculateRelativeStrengthMetrics:
    """Output contract of calculate_relative_strength_metrics."""

    @patch(
        "tradingagents.agents.discovery.indicators.relative_strength._get_price_history"
    )
    @patch("tradingagents.agents.discovery.indicators.relative_strength.get_sector_etf")
    def test_calculate_rs_metrics_returns_dict(
        self, mock_sector_etf, mock_price_history
    ):
        """All RS keys are present and the composite score lies in [0, 1]."""
        from tradingagents.agents.discovery.indicators.relative_strength import (
            calculate_relative_strength_metrics,
        )
        mock_sector_etf.return_value = "XLK"
        # Stock rises faster than SPY; the sector ETF sits between the two.
        base_prices = [100.0 + i * 0.5 for i in range(70)]
        spy_prices = [100.0 + i * 0.3 for i in range(70)]
        sector_prices = [100.0 + i * 0.4 for i in range(70)]
        def price_history_side_effect(ticker, date, days):
            # Serve the trailing `days` window of the matching synthetic series.
            if ticker == "AAPL":
                return base_prices[-days:]
            elif ticker == "SPY":
                return spy_prices[-days:]
            else:
                return sector_prices[-days:]
        mock_price_history.side_effect = price_history_side_effect
        result = calculate_relative_strength_metrics("AAPL", "2024-01-15")
        assert isinstance(result, dict)
        assert "rs_vs_spy_5d" in result
        assert "rs_vs_spy_20d" in result
        assert "rs_vs_spy_60d" in result
        assert "rs_vs_sector" in result
        assert "sector_etf" in result
        assert "relative_strength_score" in result
        assert 0.0 <= result["relative_strength_score"] <= 1.0

    @patch(
        "tradingagents.agents.discovery.indicators.relative_strength._get_price_history"
    )
    @patch("tradingagents.agents.discovery.indicators.relative_strength.get_sector_etf")
    def test_calculate_rs_metrics_handles_missing_benchmark_data(
        self, mock_sector_etf, mock_price_history
    ):
        """With no price history available, the score falls back to neutral 0.5."""
        from tradingagents.agents.discovery.indicators.relative_strength import (
            calculate_relative_strength_metrics,
        )
        mock_sector_etf.return_value = "XLK"
        mock_price_history.return_value = []
        result = calculate_relative_strength_metrics("INVALID", "2024-01-15")
        assert isinstance(result, dict)
        assert result["relative_strength_score"] == 0.5
View File

@ -0,0 +1,121 @@
import pytest
class TestCalculateStopLoss:
    """Stop-loss placement as price minus ATR times the multiplier."""

    @staticmethod
    def _stop(price, atr, multiplier):
        """Delegate to calculate_stop_loss with named arguments."""
        from tradingagents.agents.discovery.indicators.risk_reward import (
            calculate_stop_loss,
        )

        return calculate_stop_loss(price=price, atr=atr, multiplier=multiplier)

    def test_calculate_stop_loss_default_multiplier(self):
        """1.5x ATR of 2.0 below a 100.0 price is 97.0."""
        assert self._stop(100.0, 2.0, 1.5) == 97.0

    def test_calculate_stop_loss_custom_multiplier(self):
        """2.0x ATR of 2.0 below a 100.0 price is 96.0."""
        assert self._stop(100.0, 2.0, 2.0) == 96.0

    def test_calculate_stop_loss_large_atr(self):
        """A wide ATR pushes the stop proportionally further away."""
        assert self._stop(100.0, 10.0, 1.5) == 85.0
class TestCalculateRewardTarget:
    """Reward target derivation from price and resistance."""

    @staticmethod
    def _target(price, resistance):
        """Delegate to calculate_reward_target with named arguments."""
        from tradingagents.agents.discovery.indicators.risk_reward import (
            calculate_reward_target,
        )

        return calculate_reward_target(price=price, resistance=resistance)

    def test_calculate_reward_target(self):
        """Resistance above the price becomes the target."""
        assert self._target(100.0, 120.0) == 120.0

    def test_calculate_reward_target_resistance_below_price(self):
        """Resistance below the price is still returned unchanged."""
        assert self._target(100.0, 90.0) == 90.0
class TestCalculateRiskRewardRatio:
    """Reward-to-risk ratio derived from price, stop, and target."""

    @staticmethod
    def _rr(price, stop, target):
        """Delegate to calculate_risk_reward_ratio with named arguments."""
        from tradingagents.agents.discovery.indicators.risk_reward import (
            calculate_risk_reward_ratio,
        )

        return calculate_risk_reward_ratio(price=price, stop=stop, target=target)

    def test_calculate_rr_ratio_good_trade(self):
        """15 of reward against 5 of risk gives a 3:1 ratio."""
        assert self._rr(100.0, 95.0, 115.0) == 3.0

    def test_calculate_rr_ratio_poor_trade(self):
        """2 of reward against 5 of risk gives roughly 0.4."""
        assert self._rr(100.0, 95.0, 102.0) == pytest.approx(0.4, rel=0.01)

    def test_calculate_rr_ratio_stop_at_price_returns_zero(self):
        """Zero risk distance must not divide by zero; ratio degrades to 0."""
        assert self._rr(100.0, 100.0, 110.0) == 0.0

    def test_calculate_rr_ratio_target_below_price(self):
        """A target beneath the entry price yields a negative ratio."""
        assert self._rr(100.0, 95.0, 98.0) < 0
class TestCalculateRiskRewardScore:
    """Score buckets produced for various reward/risk ratios."""

    @staticmethod
    def _score(rr_ratio):
        """Delegate to calculate_risk_reward_score for the given ratio."""
        from tradingagents.agents.discovery.indicators.risk_reward import (
            calculate_risk_reward_score,
        )

        return calculate_risk_reward_score(rr_ratio=rr_ratio)

    def test_excellent_rr_ratio_high_score(self):
        """Ratios above 3 land in the top score band."""
        assert 0.9 <= self._score(3.5) <= 1.0

    def test_good_rr_ratio_moderate_high_score(self):
        """A 2.5 ratio lands in the upper-middle band."""
        assert 0.7 <= self._score(2.5) <= 0.9

    def test_acceptable_rr_ratio_moderate_score(self):
        """A 1.5 ratio lands in the middle band."""
        assert 0.4 <= self._score(1.5) <= 0.7

    def test_poor_rr_ratio_low_score(self):
        """A sub-1 ratio lands in the bottom band."""
        assert 0.0 <= self._score(0.5) <= 0.4

    def test_negative_rr_ratio_very_low_score(self):
        """A negative ratio floors the score at exactly 0.0."""
        assert self._score(-1.0) == 0.0

View File

@ -0,0 +1,224 @@
from unittest.mock import MagicMock, patch
import pytest
class TestFindSupportLevels:
    """Contract tests for find_support_levels."""

    def test_find_support_from_lows(self):
        """Support levels come from the lowest low inside each lookback window."""
        from tradingagents.agents.discovery.indicators.support_resistance import find_support_levels

        low_prices = (
            [100.0, 98.0, 97.0, 99.0, 96.0, 95.0, 97.0, 98.0, 94.0, 93.0]
            + [95.0, 96.0, 94.0, 93.0, 92.0, 91.0, 90.0, 89.0, 88.0, 87.0]
            + [88.0, 89.0, 87.0, 86.0, 85.0] * 6
        )
        short_support, long_support = find_support_levels(
            low_prices, lookback_20=20, lookback_50=50
        )
        assert short_support <= 100.0
        assert long_support <= 100.0

    def test_find_support_empty_list(self):
        """An empty low series degrades to (0.0, 0.0)."""
        from tradingagents.agents.discovery.indicators.support_resistance import find_support_levels

        short_support, long_support = find_support_levels(
            [], lookback_20=20, lookback_50=50
        )
        assert short_support == 0.0
        assert long_support == 0.0
class TestFindResistanceLevels:
    """Contract tests for find_resistance_levels."""

    def test_find_resistance_from_highs(self):
        """Resistance levels come from the highest high inside each lookback window."""
        from tradingagents.agents.discovery.indicators.support_resistance import find_resistance_levels

        high_prices = (
            [100.0, 102.0, 105.0, 103.0, 108.0, 106.0, 110.0, 107.0, 112.0, 109.0]
            + [111.0, 113.0, 115.0, 114.0, 117.0, 116.0, 120.0, 118.0, 122.0, 119.0]
            + [120.0, 121.0, 123.0, 124.0, 125.0] * 6
        )
        short_resistance, long_resistance = find_resistance_levels(
            high_prices, lookback_20=20, lookback_50=50
        )
        assert short_resistance >= 100.0
        assert long_resistance >= 100.0

    def test_find_resistance_empty_list(self):
        """An empty high series degrades to (0.0, 0.0)."""
        from tradingagents.agents.discovery.indicators.support_resistance import find_resistance_levels

        short_resistance, long_resistance = find_resistance_levels(
            [], lookback_20=20, lookback_50=50
        )
        assert short_resistance == 0.0
        assert long_resistance == 0.0
class TestDetectSwingPoints:
    """Contract tests for detect_swing_points."""

    def test_detect_swing_highs_and_lows(self):
        """A zig-zag series returns (possibly empty) lists of swing extremes."""
        from tradingagents.agents.discovery.indicators.support_resistance import detect_swing_points

        series = [
            100, 102, 105, 103, 98, 95, 97, 100, 103,
            107, 105, 102, 99, 96, 94, 97, 101,
        ]
        lows, highs = detect_swing_points(series, n_bars=3)
        assert len(lows) >= 0
        assert len(highs) >= 0

    def test_detect_swing_points_short_list(self):
        """Series shorter than the detection window still return two lists."""
        from tradingagents.agents.discovery.indicators.support_resistance import detect_swing_points

        lows, highs = detect_swing_points([100, 105, 102], n_bars=3)
        assert isinstance(lows, list)
        assert isinstance(highs, list)

    def test_detect_swing_points_empty_list(self):
        """An empty series returns two empty lists."""
        from tradingagents.agents.discovery.indicators.support_resistance import detect_swing_points

        lows, highs = detect_swing_points([], n_bars=5)
        assert lows == []
        assert highs == []
class TestGetNearestLevels:
    """Contract tests for get_nearest_levels."""

    def test_get_nearest_support_and_resistance(self):
        """Picks the highest support below and the lowest resistance above price."""
        from tradingagents.agents.discovery.indicators.support_resistance import get_nearest_levels

        support, resistance = get_nearest_levels(
            150.0, [140.0, 145.0, 130.0, 120.0], [155.0, 160.0, 170.0, 180.0]
        )
        assert support == 145.0
        assert resistance == 155.0

    def test_get_nearest_levels_no_support_below(self):
        """With no support strictly below price, support defaults to 0.0."""
        from tradingagents.agents.discovery.indicators.support_resistance import get_nearest_levels

        support, resistance = get_nearest_levels(
            100.0, [110.0, 120.0, 130.0], [150.0, 160.0]
        )
        assert support == 0.0
        assert resistance == 150.0

    def test_get_nearest_levels_empty_lists(self):
        """Empty level lists default both outputs to 0.0."""
        from tradingagents.agents.discovery.indicators.support_resistance import get_nearest_levels

        support, resistance = get_nearest_levels(150.0, [], [])
        assert support == 0.0
        assert resistance == 0.0
class TestCalculateSupportResistanceMetrics:
    """Tests for calculate_support_resistance_metrics with data access mocked.

    NOTE: @patch decorators apply bottom-up, so the _get_atr patch (listed
    last) binds to the first mock parameter (mock_atr).
    """

    @patch(
        "tradingagents.agents.discovery.indicators.support_resistance._get_ohlc_data"
    )
    @patch("tradingagents.agents.discovery.indicators.support_resistance._get_atr")
    def test_calculate_sr_metrics_returns_dict(self, mock_atr, mock_ohlc):
        """A well-formed OHLC/ATR payload produces every expected metric key."""
        from tradingagents.agents.discovery.indicators.support_resistance import (
            calculate_support_resistance_metrics,
        )
        # 60 rows of steadily rising synthetic prices.
        mock_ohlc.return_value = {
            "highs": [105.0 + i * 0.5 for i in range(60)],
            "lows": [95.0 + i * 0.3 for i in range(60)],
            "closes": [100.0 + i * 0.4 for i in range(60)],
            "current_price": 125.0,
        }
        mock_atr.return_value = 2.5
        result = calculate_support_resistance_metrics("AAPL", "2024-01-15")
        assert isinstance(result, dict)
        assert "support_level" in result
        assert "resistance_level" in result
        assert "atr" in result
        assert "suggested_stop" in result
        assert "reward_target" in result
        assert "risk_reward_ratio" in result
        assert "risk_reward_score" in result

    @patch(
        "tradingagents.agents.discovery.indicators.support_resistance._get_ohlc_data"
    )
    @patch("tradingagents.agents.discovery.indicators.support_resistance._get_atr")
    def test_calculate_sr_metrics_handles_missing_data(self, mock_atr, mock_ohlc):
        """Missing OHLC/ATR data degrades to the neutral 0.5 score, no raise."""
        from tradingagents.agents.discovery.indicators.support_resistance import (
            calculate_support_resistance_metrics,
        )
        mock_ohlc.return_value = {
            "highs": [],
            "lows": [],
            "closes": [],
            "current_price": None,
        }
        mock_atr.return_value = None
        result = calculate_support_resistance_metrics("INVALID", "2024-01-15")
        assert isinstance(result, dict)
        assert result["risk_reward_score"] == 0.5

View File

@ -0,0 +1,201 @@
from unittest.mock import MagicMock, patch
import pytest
class TestCalculateVolumeRatio:
    """Contract tests for calculate_volume_ratio."""

    def test_volume_ratio_above_average(self):
        """Double the 20-day average volume gives a ratio of 2.0."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_ratio

        result = calculate_volume_ratio(
            current_volume=2_000_000, avg_volume_20d=1_000_000
        )
        assert result == 2.0

    def test_volume_ratio_below_average(self):
        """Half the 20-day average volume gives a ratio of 0.5."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_ratio

        result = calculate_volume_ratio(current_volume=500_000, avg_volume_20d=1_000_000)
        assert result == 0.5

    def test_volume_ratio_zero_average_returns_zero(self):
        """A zero average must not divide by zero; ratio is 0.0."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_ratio

        result = calculate_volume_ratio(current_volume=1_000_000, avg_volume_20d=0)
        assert result == 0.0
class TestCalculateVolumeTrend:
    """Contract tests for calculate_volume_trend."""

    def test_volume_trend_increasing(self):
        """A strictly rising volume series is labeled increasing, positive slope."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_trend

        direction, slope = calculate_volume_trend(
            [100_000, 150_000, 200_000, 250_000, 300_000]
        )
        assert direction == "increasing"
        assert slope > 0

    def test_volume_trend_decreasing(self):
        """A strictly falling volume series is labeled decreasing, negative slope."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_trend

        direction, slope = calculate_volume_trend(
            [300_000, 250_000, 200_000, 150_000, 100_000]
        )
        assert direction == "decreasing"
        assert slope < 0

    def test_volume_trend_flat(self):
        """Small oscillations around a level classify as flat."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_trend

        direction, slope = calculate_volume_trend(
            [100_000, 100_100, 99_900, 100_050, 99_950]
        )
        assert direction == "flat"

    def test_volume_trend_empty_returns_flat(self):
        """An empty series degrades to ("flat", 0.0)."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_trend

        direction, slope = calculate_volume_trend([])
        assert direction == "flat"
        assert slope == 0.0
class TestCalculateDollarVolume:
    """Contract tests for calculate_dollar_volume."""

    def test_dollar_volume_calculation(self):
        """Dollar volume is simply price times share volume."""
        from tradingagents.agents.discovery.indicators.volume import calculate_dollar_volume

        assert calculate_dollar_volume(price=150.0, volume=1_000_000) == 150_000_000.0

    def test_dollar_volume_zero_price(self):
        """A zero price yields zero dollar volume."""
        from tradingagents.agents.discovery.indicators.volume import calculate_dollar_volume

        assert calculate_dollar_volume(price=0.0, volume=1_000_000) == 0.0
class TestCalculateVolumeScore:
    """Band tests for calculate_volume_score."""

    def test_volume_spike_with_positive_price_high_score(self):
        """A big spike on positive price action scores near the top."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_score

        result = calculate_volume_score(
            volume_ratio=2.5,
            trend="increasing",
            dollar_volume=50_000_000.0,
            price_change=5.0,
            min_dollar_volume=1_000_000.0,
        )
        assert 0.8 <= result <= 1.0

    def test_above_average_volume_moderate_score(self):
        """Moderately elevated volume lands in the upper-middle band."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_score

        result = calculate_volume_score(
            volume_ratio=1.5,
            trend="increasing",
            dollar_volume=10_000_000.0,
            price_change=2.0,
            min_dollar_volume=1_000_000.0,
        )
        assert 0.7 <= result <= 0.9

    def test_normal_volume_neutral_score(self):
        """Average volume with a flat trend stays near neutral."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_score

        result = calculate_volume_score(
            volume_ratio=1.0,
            trend="flat",
            dollar_volume=5_000_000.0,
            price_change=0.5,
            min_dollar_volume=1_000_000.0,
        )
        assert 0.3 <= result <= 0.6

    def test_low_dollar_volume_penalized(self):
        """Identical setups except liquidity: the illiquid one scores lower."""
        from tradingagents.agents.discovery.indicators.volume import calculate_volume_score

        liquid_score = calculate_volume_score(
            volume_ratio=2.0,
            trend="increasing",
            dollar_volume=50_000_000.0,
            price_change=3.0,
            min_dollar_volume=1_000_000.0,
        )
        illiquid_score = calculate_volume_score(
            volume_ratio=2.0,
            trend="increasing",
            dollar_volume=500_000.0,
            price_change=3.0,
            min_dollar_volume=1_000_000.0,
        )
        assert illiquid_score < liquid_score
class TestCalculateVolumeMetrics:
    """Tests for calculate_volume_metrics with the data fetch mocked out."""

    @patch("tradingagents.agents.discovery.indicators.volume._get_volume_price_data")
    def test_calculate_volume_metrics_returns_dict(self, mock_get_data):
        """A healthy payload produces all metric keys with a score in [0, 1]."""
        from tradingagents.agents.discovery.indicators.volume import (
            calculate_volume_metrics,
        )
        # 25 days of flat volume/price with an above-average final day.
        mock_get_data.return_value = {
            "volumes": [1_000_000] * 25,
            "prices": [100.0] * 25,
            "current_price": 105.0,
            "current_volume": 1_500_000,
            "price_change_pct": 5.0,
        }
        result = calculate_volume_metrics("AAPL", "2024-01-15")
        assert isinstance(result, dict)
        assert "volume_ratio" in result
        assert "volume_trend" in result
        assert "dollar_volume" in result
        assert "volume_score" in result
        assert 0.0 <= result["volume_score"] <= 1.0

    @patch("tradingagents.agents.discovery.indicators.volume._get_volume_price_data")
    def test_calculate_volume_metrics_handles_missing_data(self, mock_get_data):
        """Missing data degrades to the neutral 0.5 score instead of raising."""
        from tradingagents.agents.discovery.indicators.volume import (
            calculate_volume_metrics,
        )
        mock_get_data.return_value = {
            "volumes": [],
            "prices": [],
            "current_price": None,
            "current_volume": None,
            "price_change_pct": 0.0,
        }
        result = calculate_volume_metrics("INVALID", "2024-01-15")
        assert isinstance(result, dict)
        assert result["volume_score"] == 0.5

View File

@ -22,6 +22,7 @@ from .persistence import (
generate_markdown_summary,
save_discovery_result,
)
from .quantitative_models import QuantitativeMetrics
from .scorer import (
DEFAULT_DECAY_RATE,
DEFAULT_MAX_RESULTS,
@ -50,4 +51,5 @@ __all__ = [
"DEFAULT_MIN_MENTIONS",
"save_discovery_result",
"generate_markdown_summary",
"QuantitativeMetrics",
]

View File

@ -0,0 +1,61 @@
from tradingagents.agents.discovery.indicators.momentum import (
calculate_ema_direction,
calculate_macd_score,
calculate_momentum_score,
calculate_rsi_score,
calculate_sma_score,
)
from tradingagents.agents.discovery.indicators.relative_strength import (
SECTOR_ETF_MAP,
calculate_relative_strength,
calculate_relative_strength_metrics,
calculate_return,
get_sector_etf,
)
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_reward_target,
calculate_risk_reward_ratio,
calculate_risk_reward_score,
calculate_stop_loss,
)
from tradingagents.agents.discovery.indicators.support_resistance import (
calculate_support_resistance_metrics,
detect_swing_points,
find_resistance_levels,
find_support_levels,
get_nearest_levels,
)
from tradingagents.agents.discovery.indicators.volume import (
calculate_dollar_volume,
calculate_volume_metrics,
calculate_volume_ratio,
calculate_volume_score,
calculate_volume_trend,
)
__all__ = [
"calculate_rsi_score",
"calculate_macd_score",
"calculate_sma_score",
"calculate_ema_direction",
"calculate_momentum_score",
"calculate_volume_ratio",
"calculate_volume_trend",
"calculate_dollar_volume",
"calculate_volume_score",
"calculate_volume_metrics",
"SECTOR_ETF_MAP",
"calculate_return",
"calculate_relative_strength",
"get_sector_etf",
"calculate_relative_strength_metrics",
"find_support_levels",
"find_resistance_levels",
"detect_swing_points",
"get_nearest_levels",
"calculate_support_resistance_metrics",
"calculate_stop_loss",
"calculate_reward_target",
"calculate_risk_reward_ratio",
"calculate_risk_reward_score",
]

View File

@ -0,0 +1,349 @@
import logging
from datetime import datetime
import pandas as pd
logger = logging.getLogger(__name__)
def calculate_rsi_score(rsi: float) -> float:
    """Map an RSI reading onto a [0, 1] conviction score.

    Oversold readings (<= 20) score 1.0; overbought readings (> 80) decay
    toward 0. Within each band the score falls linearly, and adjacent bands
    meet at their shared edge so the mapping is continuous.
    """
    if rsi <= 20:
        return 1.0
    # (upper_bound, base_score, divisor): score = base + (upper - rsi) / divisor
    bands = (
        (30, 0.8, 50),
        (35, 0.7, 50),
        (50, 0.6, 150),
        (65, 0.5, 150),
        (70, 0.4, 50),
        (80, 0.2, 50),
    )
    for upper, base, divisor in bands:
        if rsi <= upper:
            return base + (upper - rsi) / divisor
    # Beyond 80: decay from 0.2, floored at 0.
    return max(0.0, 0.2 - (rsi - 80) / 100)
def calculate_macd_score(macd: float, signal: float, histogram: float) -> float:
    """Score MACD posture on [0, 1].

    Starts at a neutral 0.5; the MACD/signal crossover and the histogram each
    move the score by up to +/- 0.25, normalized by the magnitude of the
    signal/MACD lines (with a 0.01 floor to avoid division blow-ups).
    """
    signal_scale = max(abs(signal), 0.01)
    macd_scale = max(abs(macd), 0.01)
    score = 0.5
    if macd > signal:
        score += 0.25 * min((macd - signal) / signal_scale, 1.0)
    else:
        score -= 0.25 * min((signal - macd) / signal_scale, 1.0)
    if histogram > 0:
        score += min(histogram / macd_scale, 1.0) * 0.25
    else:
        score -= min(abs(histogram) / macd_scale, 1.0) * 0.25
    return max(0.0, min(1.0, score))
def calculate_sma_score(price: float, sma50: float, sma200: float) -> float:
    """Score price posture versus its 50/200-day SMAs on [0, 1].

    Each SMA contributes up to +/- 0.25 proportional to twice the percentage
    distance of price from it; a bullish stack (price > sma50 > sma200) adds a
    0.15 bonus. Returns a neutral 0.5 when either SMA is zero/unavailable.
    """
    if sma50 == 0 or sma200 == 0:
        return 0.5
    score = 0.5
    for sma in (sma50, sma200):
        delta = ((price - sma) / sma) * 2
        score += min(delta, 0.25) if delta > 0 else max(delta, -0.25)
    if price > sma50 > sma200:
        score += 0.15  # golden-cross style alignment bonus
    return max(0.0, min(1.0, score))
def calculate_ema_direction(ema_values: list[float]) -> str:
    """Classify an EMA series (chronological order) as "up", "down", or "flat".

    Compares the mean of the older half against the newer half; moves of more
    than 1% in either direction determine the label.
    """
    if len(ema_values) < 2:
        return "flat"
    mid = len(ema_values) // 2
    older, newer = ema_values[:mid], ema_values[mid:]
    if not older or not newer:
        return "flat"
    older_avg = sum(older) / len(older)
    newer_avg = sum(newer) / len(newer)
    change = (newer_avg - older_avg) / older_avg if older_avg != 0 else 0
    if change > 0.01:
        return "up"
    if change < -0.01:
        return "down"
    return "flat"
def _get_stock_stats_bulk(symbol: str, indicator: str, curr_date: str) -> dict:
    """Return {date-string: indicator-value-string} for *symbol* over all cached history.

    Values are str(...) of the stockstats column, with NaN mapped to "N/A".
    Offline ("local") vendor mode reads a fixed-range pre-fetched CSV;
    otherwise up to 15 years of Yahoo Finance data is downloaded and cached.
    NOTE(review): *curr_date* is unused here; callers index the mapping.
    """
    import os
    from stockstats import wrap
    from tradingagents.dataflows.config import get_config
    config = get_config()
    online = config["data_vendors"]["technical_indicators"] != "local"
    if not online:
        try:
            data = pd.read_csv(
                os.path.join(
                    config.get("data_cache_dir", "data"),
                    f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
                )
            )
            df = wrap(data)
        except FileNotFoundError:
            raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
    else:
        import yfinance as yf
        today_date = pd.Timestamp.today()
        end_date = today_date
        start_date = today_date - pd.DateOffset(years=15)
        start_date_str = start_date.strftime("%Y-%m-%d")
        end_date_str = end_date.strftime("%Y-%m-%d")
        os.makedirs(config["data_cache_dir"], exist_ok=True)
        # Cache file is keyed by exact start/end dates, so a new day's run
        # downloads a fresh file rather than reusing yesterday's.
        data_file = os.path.join(
            config["data_cache_dir"],
            f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
        )
        if os.path.exists(data_file):
            data = pd.read_csv(data_file)
            data["Date"] = pd.to_datetime(data["Date"])
        else:
            data = yf.download(
                symbol,
                start=start_date_str,
                end=end_date_str,
                multi_level_index=False,
                progress=False,
                auto_adjust=True,
            )
            data = data.reset_index()
            data.to_csv(data_file, index=False)
        df = wrap(data)
    # NOTE(review): the .dt accessor assumes Date is datetime-typed; in the
    # offline branch Date comes straight from read_csv (likely str) — confirm.
    df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
    # Touching the column makes stockstats compute and cache the indicator.
    df[indicator]
    indicator_series = df[indicator].apply(lambda x: "N/A" if pd.isna(x) else str(x))
    result_dict = dict(zip(df["Date"], indicator_series, strict=False))
    return result_dict
def _get_price_data(symbol: str, curr_date: str) -> dict:
    """Return {date-string: close-price} for *symbol* over all cached history.

    Offline ("local") vendor mode reads a fixed-range pre-fetched CSV and
    returns {} when it is missing; otherwise up to 15 years of Yahoo Finance
    data is downloaded and cached on disk.
    NOTE(review): *curr_date* is unused here; callers index the mapping.
    """
    import os
    from tradingagents.dataflows.config import get_config
    config = get_config()
    online = config["data_vendors"]["technical_indicators"] != "local"
    if not online:
        try:
            data = pd.read_csv(
                os.path.join(
                    config.get("data_cache_dir", "data"),
                    f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
                )
            )
        except FileNotFoundError:
            return {}
    else:
        import yfinance as yf
        today_date = pd.Timestamp.today()
        end_date = today_date
        start_date = today_date - pd.DateOffset(years=15)
        start_date_str = start_date.strftime("%Y-%m-%d")
        end_date_str = end_date.strftime("%Y-%m-%d")
        os.makedirs(config["data_cache_dir"], exist_ok=True)
        data_file = os.path.join(
            config["data_cache_dir"],
            f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
        )
        if os.path.exists(data_file):
            data = pd.read_csv(data_file)
        else:
            data = yf.download(
                symbol,
                start=start_date_str,
                end=end_date_str,
                multi_level_index=False,
                progress=False,
                auto_adjust=True,
            )
            data = data.reset_index()
            data.to_csv(data_file, index=False)
    data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d")
    # Build the mapping in one vectorized pass instead of the much slower
    # per-row iterrows loop (duplicate dates still overwrite, same as before).
    return dict(zip(data["Date"], data["Close"], strict=False))
def _series_value_at(series: dict, date: str) -> float | None:
    """Return series[date] coerced to float, or None when the date is missing,
    the value is the "N/A" sentinel, or it cannot be parsed as a float."""
    if date in series and series[date] != "N/A":
        try:
            return float(series[date])
        except (ValueError, TypeError):
            return None
    return None


def calculate_momentum_score(ticker: str, curr_date: str) -> dict:
    """Compute momentum indicators and a blended [0, 1] momentum score.

    Returns a dict of raw readings (rsi, macd, macd_signal, macd_histogram,
    price_vs_sma50/price_vs_sma200 in percent, ema10_direction) — each left as
    None when unavailable — plus "momentum_score", which falls back to the
    neutral 0.5 when inputs are missing or any data fetch fails.
    """
    result = {
        "rsi": None,
        "macd": None,
        "macd_signal": None,
        "macd_histogram": None,
        "price_vs_sma50": None,
        "price_vs_sma200": None,
        "ema10_direction": None,
        "momentum_score": 0.5,
    }
    try:
        rsi_data = _get_stock_stats_bulk(ticker, "rsi", curr_date)
        macd_data = _get_stock_stats_bulk(ticker, "macd", curr_date)
        macds_data = _get_stock_stats_bulk(ticker, "macds", curr_date)
        macdh_data = _get_stock_stats_bulk(ticker, "macdh", curr_date)
        sma50_data = _get_stock_stats_bulk(ticker, "close_50_sma", curr_date)
        sma200_data = _get_stock_stats_bulk(ticker, "close_200_sma", curr_date)
        ema10_data = _get_stock_stats_bulk(ticker, "close_10_ema", curr_date)
        price_data = _get_price_data(ticker, curr_date)

        rsi_value = _series_value_at(rsi_data, curr_date)
        if rsi_value is not None:
            result["rsi"] = rsi_value
        macd_value = _series_value_at(macd_data, curr_date)
        if macd_value is not None:
            result["macd"] = macd_value
        macds_value = _series_value_at(macds_data, curr_date)
        if macds_value is not None:
            result["macd_signal"] = macds_value
        macdh_value = _series_value_at(macdh_data, curr_date)
        if macdh_value is not None:
            result["macd_histogram"] = macdh_value

        current_price = _series_value_at(price_data, curr_date)
        sma50_value = _series_value_at(sma50_data, curr_date)
        sma200_value = _series_value_at(sma200_data, curr_date)
        # Truthiness (not `is None`) intentionally also skips zero values,
        # which would otherwise divide by zero below.
        if current_price and sma50_value:
            result["price_vs_sma50"] = (
                (current_price - sma50_value) / sma50_value
            ) * 100
        if current_price and sma200_value:
            result["price_vs_sma200"] = (
                (current_price - sma200_value) / sma200_value
            ) * 100

        # Trend of the 10-day EMA over the 10 most recent dates on/before
        # curr_date (collected newest-first, then reversed to chronological).
        ema_values = []
        sorted_dates = sorted(
            [d for d in ema10_data.keys() if d <= curr_date], reverse=True
        )[:10]
        for date in sorted_dates:
            ema_value = _series_value_at(ema10_data, date)
            if ema_value is not None:
                ema_values.append(ema_value)
        if ema_values:
            result["ema10_direction"] = calculate_ema_direction(ema_values[::-1])

        # Weighted blend of whichever component scores are available:
        # RSI 25%, MACD 35%, SMA posture 40%, renormalized over what's present.
        scores = []
        weights = []
        if rsi_value is not None:
            scores.append(calculate_rsi_score(rsi_value))
            weights.append(0.25)
        if (
            macd_value is not None
            and macds_value is not None
            and macdh_value is not None
        ):
            scores.append(calculate_macd_score(macd_value, macds_value, macdh_value))
            weights.append(0.35)
        if current_price and sma50_value and sma200_value:
            scores.append(calculate_sma_score(current_price, sma50_value, sma200_value))
            weights.append(0.40)
        if scores and weights:
            total_weight = sum(weights)
            result["momentum_score"] = (
                sum(s * w for s, w in zip(scores, weights, strict=False)) / total_weight
            )
        else:
            result["momentum_score"] = 0.5
    except (KeyError, ValueError, FileNotFoundError, RuntimeError) as e:
        logger.warning("Failed to calculate momentum score for %s: %s", ticker, str(e))
        result["momentum_score"] = 0.5
    return result

View File

@ -0,0 +1,181 @@
import logging
import os
import pandas as pd
from tradingagents.dataflows.trending.sector_classifier import classify_sector
logger = logging.getLogger(__name__)
# Sector name (as produced by classify_sector) -> sector-ETF proxy used as the
# relative-strength benchmark; "other"/unknown sectors fall back to SPY.
SECTOR_ETF_MAP = {
    "technology": "XLK",
    "finance": "XLF",
    "healthcare": "XLV",
    "energy": "XLE",
    "consumer_goods": "XLY",
    "industrials": "XLI",
    "other": "SPY",
}
def calculate_return(prices: list[float], days: int) -> float:
    """Percent change from the price *days* entries back to the latest price.

    Returns 0.0 for series shorter than two points or when the anchor price is
    zero (avoiding a divide-by-zero).
    """
    if len(prices) < 2:
        return 0.0
    anchor = prices[max(0, len(prices) - days)]
    latest = prices[-1]
    if anchor == 0:
        return 0.0
    return ((latest - anchor) / anchor) * 100
def calculate_relative_strength(stock_return: float, benchmark_return: float) -> float:
    """Excess return of the stock over its benchmark (percentage points)."""
    excess = stock_return - benchmark_return
    return excess
def get_sector_etf(ticker: str) -> str:
    """Map *ticker*'s classified sector to its benchmark ETF (SPY fallback)."""
    return SECTOR_ETF_MAP.get(classify_sector(ticker), "SPY")
def _get_price_history(ticker: str, curr_date: str, days: int) -> list[float]:
    """Return up to *days* + 5 closing prices for *ticker* ending at *curr_date*.

    Offline ("local") vendor mode reads a fixed-range pre-fetched CSV; online
    mode downloads and caches up to 15 years of Yahoo Finance data. Returns []
    on any data-access failure. The extra 5 rows give downstream return
    calculations slack for shorter trading weeks.
    """
    from tradingagents.dataflows.config import get_config
    config = get_config()
    online = config["data_vendors"]["technical_indicators"] != "local"
    try:
        if not online:
            data = pd.read_csv(
                os.path.join(
                    config.get("data_cache_dir", "data"),
                    f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv",
                )
            )
        else:
            import yfinance as yf
            today_date = pd.Timestamp.today()
            end_date = today_date
            start_date = today_date - pd.DateOffset(years=15)
            start_date_str = start_date.strftime("%Y-%m-%d")
            end_date_str = end_date.strftime("%Y-%m-%d")
            os.makedirs(config["data_cache_dir"], exist_ok=True)
            # Cache keyed by exact date range -> refreshed on a new day.
            data_file = os.path.join(
                config["data_cache_dir"],
                f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv",
            )
            if os.path.exists(data_file):
                data = pd.read_csv(data_file)
            else:
                data = yf.download(
                    ticker,
                    start=start_date_str,
                    end=end_date_str,
                    multi_level_index=False,
                    progress=False,
                    auto_adjust=True,
                )
                data = data.reset_index()
                data.to_csv(data_file, index=False)
        # Normalize dates to YYYY-MM-DD strings, keep only rows on/before
        # curr_date, then take the trailing window.
        data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d")
        data = data[data["Date"] <= curr_date].tail(days + 5)
        return data["Close"].tolist()
    except (FileNotFoundError, KeyError, ValueError) as e:
        logger.warning("Failed to get price history for %s: %s", ticker, str(e))
        return []
def _calculate_rs_score(
rs_5d: float, rs_20d: float, rs_60d: float, rs_sector: float
) -> float:
score = 0.5
if rs_20d > 5:
score += min(rs_20d / 20, 0.25)
elif rs_20d < -5:
score -= min(abs(rs_20d) / 20, 0.25)
if rs_5d > 3:
score += min(rs_5d / 15, 0.15)
elif rs_5d < -3:
score -= min(abs(rs_5d) / 15, 0.15)
if rs_60d > 10:
score += min(rs_60d / 40, 0.1)
elif rs_60d < -10:
score -= min(abs(rs_60d) / 40, 0.1)
if rs_sector > 3:
score += min(rs_sector / 15, 0.1)
elif rs_sector < -3:
score -= min(abs(rs_sector) / 15, 0.1)
return max(0.0, min(1.0, score))
def calculate_relative_strength_metrics(ticker: str, curr_date: str) -> dict:
    """Compute relative-strength metrics for *ticker* and a blended [0, 1] score.

    Compares 5/20/60-day returns against SPY and the ticker's sector ETF.
    Fields stay None (score 0.5, neutral) when price history is unavailable or
    a computation fails.
    """
    result = {
        "rs_vs_spy_5d": None,
        "rs_vs_spy_20d": None,
        "rs_vs_spy_60d": None,
        "rs_vs_sector": None,
        "sector_etf": None,
        "relative_strength_score": 0.5,
    }
    try:
        # 70 rows comfortably covers the longest (60-day) lookback.
        stock_prices = _get_price_history(ticker, curr_date, 70)
        spy_prices = _get_price_history("SPY", curr_date, 70)
        if not stock_prices or not spy_prices:
            return result
        sector_etf = get_sector_etf(ticker)
        result["sector_etf"] = sector_etf
        sector_prices = []
        if sector_etf != "SPY":
            sector_prices = _get_price_history(sector_etf, curr_date, 70)
        stock_5d = calculate_return(stock_prices, 5)
        stock_20d = calculate_return(stock_prices, 20)
        stock_60d = calculate_return(stock_prices, 60)
        spy_5d = calculate_return(spy_prices, 5)
        spy_20d = calculate_return(spy_prices, 20)
        spy_60d = calculate_return(spy_prices, 60)
        result["rs_vs_spy_5d"] = calculate_relative_strength(stock_5d, spy_5d)
        result["rs_vs_spy_20d"] = calculate_relative_strength(stock_20d, spy_20d)
        result["rs_vs_spy_60d"] = calculate_relative_strength(stock_60d, spy_60d)
        # When the sector proxy is SPY itself (or its history is missing),
        # reuse the 20-day vs-SPY comparison instead of a separate fetch.
        if sector_prices and sector_etf != "SPY":
            sector_20d = calculate_return(sector_prices, 20)
            result["rs_vs_sector"] = calculate_relative_strength(stock_20d, sector_20d)
        else:
            result["rs_vs_sector"] = result["rs_vs_spy_20d"]
        result["relative_strength_score"] = _calculate_rs_score(
            result["rs_vs_spy_5d"],
            result["rs_vs_spy_20d"],
            result["rs_vs_spy_60d"],
            result["rs_vs_sector"],
        )
    except (KeyError, ValueError, RuntimeError) as e:
        logger.warning(
            "Failed to calculate relative strength for %s: %s", ticker, str(e)
        )
    return result

View File

@ -0,0 +1,34 @@
import logging
logger = logging.getLogger(__name__)
def calculate_stop_loss(price: float, atr: float, multiplier: float = 1.5) -> float:
    """Volatility-based stop: *multiplier* ATRs below *price*."""
    cushion = atr * multiplier
    return price - cushion
def calculate_reward_target(price: float, resistance: float) -> float:
    """Reward target is the nearest resistance level itself.

    *price* is currently unused but kept for interface symmetry with the
    other risk/reward helpers.
    """
    return resistance
def calculate_risk_reward_ratio(price: float, stop: float, target: float) -> float:
    """Reward-to-risk ratio of a trade setup.

    Returns 0.0 when the stop sits exactly at *price* (zero risk); negative
    when *target* lies below *price*.
    """
    downside = price - stop
    if downside == 0:
        return 0.0
    upside = target - price
    return upside / downside
def calculate_risk_reward_score(rr_ratio: float) -> float:
    """Map a reward/risk ratio onto a [0, 1] score.

    Negative ratios score 0; sub-1 ratios scale linearly up to 0.4; 1-2 and
    2-3 climb through 0.4-0.7 and 0.7-0.9; ratios of 3+ approach 1.0.
    """
    if rr_ratio < 0:
        return 0.0
    if rr_ratio < 1.0:
        return rr_ratio * 0.4
    if rr_ratio < 2.0:
        return 0.4 + (rr_ratio - 1.0) * 0.3
    if rr_ratio < 3.0:
        return 0.7 + (rr_ratio - 2.0) / 5
    return 0.9 + min((rr_ratio - 3.0) / 10, 0.1)

View File

@ -0,0 +1,260 @@
import logging
import os
import pandas as pd
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_reward_target,
calculate_risk_reward_ratio,
calculate_risk_reward_score,
calculate_stop_loss,
)
logger = logging.getLogger(__name__)
def find_support_levels(
    lows: list[float], lookback_20: int, lookback_50: int
) -> tuple[float, float]:
    """Return (20-day, 50-day) support levels: the minimum low in each window.

    Empty input yields (0.0, 0.0). A slice ``lows[-n:]`` on a list shorter
    than ``n`` is already the whole list, so the original's explicit length
    branches (and its dead ``if lows else 0.0`` fallback after the early
    return) collapse to a single expression per window.
    """
    if not lows:
        return (0.0, 0.0)
    return (min(lows[-lookback_20:]), min(lows[-lookback_50:]))
def find_resistance_levels(
    highs: list[float], lookback_20: int, lookback_50: int
) -> tuple[float, float]:
    """Return (20-day, 50-day) resistance levels: the maximum high in each window.

    Empty input yields (0.0, 0.0). A slice ``highs[-n:]`` on a list shorter
    than ``n`` is already the whole list, so the original's explicit length
    branches (and its dead ``if highs else 0.0`` fallback after the early
    return) collapse to a single expression per window.
    """
    if not highs:
        return (0.0, 0.0)
    return (max(highs[-lookback_20:]), max(highs[-lookback_50:]))
def detect_swing_points(
    prices: list[float], n_bars: int = 5
) -> tuple[list[float], list[float]]:
    """Find local extrema: bars strictly below/above every neighbor within
    *n_bars* on each side.

    Returns (swing_lows, swing_highs); series too short for a full window on
    both sides yield two empty lists.
    """
    swing_lows: list[float] = []
    swing_highs: list[float] = []
    if len(prices) < 2 * n_bars + 1:
        return (swing_lows, swing_highs)
    for idx in range(n_bars, len(prices) - n_bars):
        center = prices[idx]
        neighbors = prices[idx - n_bars : idx] + prices[idx + 1 : idx + n_bars + 1]
        if all(p > center for p in neighbors):
            swing_lows.append(center)
        if all(p < center for p in neighbors):
            swing_highs.append(center)
    return (swing_lows, swing_highs)
def get_nearest_levels(
    price: float, supports: list[float], resistances: list[float]
) -> tuple[float, float]:
    """Nearest support strictly below and resistance strictly above *price*.

    Either side defaults to 0.0 when no qualifying level exists.
    """
    nearest_support = max((s for s in supports if s < price), default=0.0)
    nearest_resistance = min((r for r in resistances if r > price), default=0.0)
    return (nearest_support, nearest_resistance)
def _get_ohlc_data(ticker: str, curr_date: str) -> dict:
    """Return the last 60 rows of OHLC data for *ticker* on/before *curr_date*.

    Result keys: "highs", "lows", "closes" (lists of floats) and
    "current_price" (last close, or None). On any data-access failure the
    lists stay empty and the price stays None. Offline ("local") vendor mode
    reads a fixed-range pre-fetched CSV; online mode downloads and caches up
    to 15 years of Yahoo Finance data.
    """
    from tradingagents.dataflows.config import get_config
    config = get_config()
    online = config["data_vendors"]["technical_indicators"] != "local"
    result = {
        "highs": [],
        "lows": [],
        "closes": [],
        "current_price": None,
    }
    try:
        if not online:
            data = pd.read_csv(
                os.path.join(
                    config.get("data_cache_dir", "data"),
                    f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv",
                )
            )
        else:
            import yfinance as yf
            today_date = pd.Timestamp.today()
            end_date = today_date
            start_date = today_date - pd.DateOffset(years=15)
            start_date_str = start_date.strftime("%Y-%m-%d")
            end_date_str = end_date.strftime("%Y-%m-%d")
            os.makedirs(config["data_cache_dir"], exist_ok=True)
            # Cache keyed by exact date range -> refreshed on a new day.
            data_file = os.path.join(
                config["data_cache_dir"],
                f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv",
            )
            if os.path.exists(data_file):
                data = pd.read_csv(data_file)
            else:
                data = yf.download(
                    ticker,
                    start=start_date_str,
                    end=end_date_str,
                    multi_level_index=False,
                    progress=False,
                    auto_adjust=True,
                )
                data = data.reset_index()
                data.to_csv(data_file, index=False)
        # Normalize dates, keep rows on/before curr_date, take last 60 rows.
        data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d")
        data = data[data["Date"] <= curr_date].tail(60)
        if len(data) > 0:
            result["highs"] = data["High"].tolist()
            result["lows"] = data["Low"].tolist()
            result["closes"] = data["Close"].tolist()
            result["current_price"] = data["Close"].iloc[-1] if len(data) > 0 else None
    except (FileNotFoundError, KeyError, ValueError) as e:
        logger.warning("Failed to get OHLC data for %s: %s", ticker, str(e))
    return result
def _get_atr(ticker: str, curr_date: str) -> float | None:
    """Return the ATR reading for *ticker* on *curr_date*, or None when
    unavailable, marked "N/A", or the lookup fails."""
    from tradingagents.agents.discovery.indicators.momentum import _get_stock_stats_bulk
    try:
        atr_series = _get_stock_stats_bulk(ticker, "atr", curr_date)
        value = atr_series.get(curr_date)
        if value is not None and value != "N/A":
            return float(value)
    except (KeyError, ValueError, FileNotFoundError) as e:
        logger.warning("Failed to get ATR for %s: %s", ticker, str(e))
    return None
def calculate_support_resistance_metrics(ticker: str, curr_date: str) -> dict:
    """Derive support/resistance levels and risk/reward metrics for *ticker*.

    Pools 20/50-day window extremes with detected swing points, picks the
    levels nearest the current price, sizes a stop from ATR (or a
    support-distance percentage fallback), and scores the reward/risk ratio.
    Fields stay None (score 0.5, neutral) when data is missing or a step fails.
    """
    result = {
        "support_level": None,
        "resistance_level": None,
        "atr": None,
        "suggested_stop": None,
        "reward_target": None,
        "risk_reward_ratio": None,
        "risk_reward_score": 0.5,
    }
    try:
        ohlc = _get_ohlc_data(ticker, curr_date)
        atr = _get_atr(ticker, curr_date)
        highs = ohlc["highs"]
        lows = ohlc["lows"]
        closes = ohlc["closes"]
        current_price = ohlc["current_price"]
        if not highs or not lows or current_price is None:
            return result
        result["atr"] = atr
        support_20d, support_50d = find_support_levels(lows, 20, 50)
        resistance_20d, resistance_50d = find_resistance_levels(highs, 20, 50)
        swing_lows, swing_highs = detect_swing_points(closes, n_bars=5)
        # Pool window extremes with swing points; drop zero placeholders.
        all_supports = [s for s in [support_20d, support_50d] + swing_lows if s > 0]
        all_resistances = [
            r for r in [resistance_20d, resistance_50d] + swing_highs if r > 0
        ]
        nearest_support, nearest_resistance = get_nearest_levels(
            current_price, all_supports, all_resistances
        )
        # Fall back to the 20-day extreme when no level sits strictly
        # below/above the current price.
        result["support_level"] = (
            nearest_support if nearest_support > 0 else support_20d
        )
        result["resistance_level"] = (
            nearest_resistance if nearest_resistance > 0 else resistance_20d
        )
        if atr and atr > 0:
            result["suggested_stop"] = calculate_stop_loss(
                current_price, atr, multiplier=1.5
            )
        else:
            # No ATR: stop at the support distance, but at least 2% below price.
            pct_to_support = (
                (current_price - result["support_level"]) / current_price
                if result["support_level"] > 0
                else 0.02
            )
            result["suggested_stop"] = current_price * (1 - max(pct_to_support, 0.02))
        result["reward_target"] = calculate_reward_target(
            current_price, result["resistance_level"]
        )
        if result["suggested_stop"] and result["reward_target"]:
            result["risk_reward_ratio"] = calculate_risk_reward_ratio(
                current_price, result["suggested_stop"], result["reward_target"]
            )
            result["risk_reward_score"] = calculate_risk_reward_score(
                result["risk_reward_ratio"]
            )
    except (KeyError, ValueError, RuntimeError) as e:
        logger.warning(
            "Failed to calculate support/resistance metrics for %s: %s", ticker, str(e)
        )
    return result

View File

@ -0,0 +1,203 @@
import logging
logger = logging.getLogger(__name__)
def determine_signal_from_indicators(
rsi: float | None,
macd: float | None,
macd_signal: float | None,
price_vs_sma: float | None,
ema_direction: str | None,
) -> str:
bullish_signals = 0
bearish_signals = 0
total_signals = 0
if rsi is not None:
total_signals += 1
if rsi < 40:
bullish_signals += 1
elif rsi > 60:
bearish_signals += 1
if macd is not None and macd_signal is not None:
total_signals += 1
if macd > macd_signal:
bullish_signals += 1
else:
bearish_signals += 1
if price_vs_sma is not None:
total_signals += 1
if price_vs_sma > 0:
bullish_signals += 1
elif price_vs_sma < 0:
bearish_signals += 1
if ema_direction is not None:
total_signals += 1
if ema_direction == "up":
bullish_signals += 1
elif ema_direction == "down":
bearish_signals += 1
if total_signals == 0:
return "neutral"
bullish_ratio = bullish_signals / total_signals
bearish_ratio = bearish_signals / total_signals
if bullish_ratio >= 0.6:
return "bullish"
elif bearish_ratio >= 0.6:
return "bearish"
else:
return "neutral"
def _tally_votes(votes: list) -> str:
    """Map a vote list (True=bullish, False=bearish, None=counted but
    neutral) to a signal; 60% agreement is required, otherwise neutral."""
    if not votes:
        return "neutral"
    if votes.count(True) / len(votes) >= 0.6:
        return "bullish"
    if votes.count(False) / len(votes) >= 0.6:
        return "bearish"
    return "neutral"


def calculate_timeframe_signals(
    momentum_data: dict,
    relative_strength_data: dict,
) -> dict:
    """Derive short/medium/long-term signals and their cross-timeframe alignment.

    Each timeframe tallies up to three indicator votes; a timeframe signal
    needs 60% agreement among its available votes. Alignment summarises the
    three signals: "aligned_bullish"/"aligned_bearish" when unanimous
    (strength 1.0/0.0), "mixed" on a 2-of-3 majority (0.7 bullish-leaning,
    0.3 bearish-leaning), otherwise "neutral" (0.5). Neutral defaults are
    returned on any lookup/type failure.
    """
    result = {
        "short_term_signal": "neutral",
        "medium_term_signal": "neutral",
        "long_term_signal": "neutral",
        "timeframe_alignment": "neutral",
        "signal_strength": 0.5,
    }
    try:
        rsi = momentum_data.get("rsi")
        macd = momentum_data.get("macd")
        macd_signal = momentum_data.get("macd_signal")
        ema_direction = momentum_data.get("ema10_direction")
        price_vs_sma50 = momentum_data.get("price_vs_sma50")
        price_vs_sma200 = momentum_data.get("price_vs_sma200")
        rs_5d = relative_strength_data.get("rs_vs_spy_5d")
        rs_20d = relative_strength_data.get("rs_vs_spy_20d")
        rs_60d = relative_strength_data.get("rs_vs_spy_60d")

        # Short term: RSI extremes, 10-day EMA direction, 5-day RS vs SPY.
        # Note: a "flat" EMA does not cast a vote at all here.
        short_votes: list = []
        if rsi is not None:
            short_votes.append(True if rsi < 35 else False if rsi > 65 else None)
        if ema_direction in ("up", "down"):
            short_votes.append(ema_direction == "up")
        if rs_5d is not None:
            short_votes.append(True if rs_5d > 0 else False if rs_5d < 0 else None)
        result["short_term_signal"] = _tally_votes(short_votes)

        # Medium term: MACD crossover, price vs 50-day SMA, 20-day RS.
        med_votes: list = []
        if macd is not None and macd_signal is not None:
            med_votes.append(macd > macd_signal)
        if price_vs_sma50 is not None:
            # Asymmetric band: any premium is bullish, but only a discount
            # deeper than -2 counts as bearish.
            med_votes.append(
                True if price_vs_sma50 > 0 else False if price_vs_sma50 < -2 else None
            )
        if rs_20d is not None:
            med_votes.append(True if rs_20d > 0 else False if rs_20d < 0 else None)
        result["medium_term_signal"] = _tally_votes(med_votes)

        # Long term: price vs 200-day SMA (bearish only below -5), 60-day RS,
        # and agreement of the two moving-average premiums.
        long_votes: list = []
        if price_vs_sma200 is not None:
            long_votes.append(
                True if price_vs_sma200 > 0 else False if price_vs_sma200 < -5 else None
            )
        if rs_60d is not None:
            long_votes.append(True if rs_60d > 0 else False if rs_60d < 0 else None)
        if price_vs_sma50 is not None and price_vs_sma200 is not None:
            if price_vs_sma50 > 0 and price_vs_sma200 > 0:
                long_votes.append(True)
            elif price_vs_sma50 < 0 and price_vs_sma200 < 0:
                long_votes.append(False)
            else:
                long_votes.append(None)
        result["long_term_signal"] = _tally_votes(long_votes)

        # Collapse the three timeframe signals into an alignment label.
        timeframe_signals = [
            result["short_term_signal"],
            result["medium_term_signal"],
            result["long_term_signal"],
        ]
        n_bullish = timeframe_signals.count("bullish")
        n_bearish = timeframe_signals.count("bearish")
        if n_bullish == 3:
            result["timeframe_alignment"] = "aligned_bullish"
            result["signal_strength"] = 1.0
        elif n_bearish == 3:
            result["timeframe_alignment"] = "aligned_bearish"
            result["signal_strength"] = 0.0
        elif n_bullish >= 2:
            result["timeframe_alignment"] = "mixed"
            result["signal_strength"] = 0.7
        elif n_bearish >= 2:
            result["timeframe_alignment"] = "mixed"
            result["signal_strength"] = 0.3
        else:
            result["timeframe_alignment"] = "neutral"
            result["signal_strength"] = 0.5
    except (KeyError, TypeError, ValueError) as e:
        logger.warning("Failed to calculate timeframe signals: %s", str(e))
    return result

View File

@ -0,0 +1,213 @@
import logging
import os
import numpy as np
import pandas as pd
from tradingagents.config import get_settings
logger = logging.getLogger(__name__)
def calculate_volume_ratio(current_volume: float, avg_volume_20d: float) -> float:
    """Return current volume relative to its 20-day average (0.0 if the average is zero)."""
    return 0.0 if avg_volume_20d == 0 else current_volume / avg_volume_20d
def calculate_volume_trend(volume_series: list[float]) -> tuple[str, float]:
    """Classify the volume trend via a linear least-squares slope.

    Returns (label, raw_slope) where label is "increasing", "decreasing",
    or "flat" depending on whether the per-bar slope exceeds +/-2% of the
    mean volume. Series shorter than two points, or with zero mean, are
    reported as ("flat", 0.0).
    """
    if len(volume_series) < 2:
        return ("flat", 0.0)
    ys = np.array(volume_series)
    xs = np.arange(len(volume_series))
    slope = float(np.polyfit(xs, ys, 1)[0])
    mean_volume = np.mean(ys)
    if mean_volume == 0:
        return ("flat", 0.0)
    relative_slope = slope / mean_volume
    if relative_slope > 0.02:
        return ("increasing", slope)
    if relative_slope < -0.02:
        return ("decreasing", slope)
    return ("flat", slope)
def calculate_dollar_volume(price: float, volume: float) -> float:
    """Notional traded value: share price times share volume."""
    return volume * price
def calculate_volume_score(
    volume_ratio: float,
    trend: str,
    dollar_volume: float,
    price_change: float,
    min_dollar_volume: float,
) -> float:
    """Score volume behaviour on [0, 1], starting from a neutral 0.5.

    Elevated relative volume earns a bonus (larger when the price moved
    up on the day), sub-average volume a small penalty; an "increasing"
    trend adds 0.1 and "decreasing" subtracts 0.05; dollar liquidity
    below *min_dollar_volume* is penalised proportionally, capped at 0.3.
    The final score is clamped to [0, 1].
    """
    up_move = price_change > 0
    # Volume-surge component, tiered by ratio to the 20-day average.
    if volume_ratio >= 2.0:
        surge_bonus = 0.35 if up_move else 0.15
    elif volume_ratio >= 1.5:
        surge_bonus = 0.25 if up_move else 0.10
    elif volume_ratio >= 1.0:
        surge_bonus = 0.05
    else:
        surge_bonus = -0.1
    trend_adjust = {"increasing": 0.1, "decreasing": -0.05}.get(trend, 0.0)
    liquidity_penalty = 0.0
    if dollar_volume < min_dollar_volume:
        shortfall = (min_dollar_volume - dollar_volume) / min_dollar_volume
        liquidity_penalty = min(0.3, shortfall * 0.3)
    total = 0.5 + surge_bonus + trend_adjust - liquidity_penalty
    return max(0.0, min(1.0, total))
def _get_volume_price_data(ticker: str, curr_date: str) -> dict:
    """Load up to 30 trading days of close/volume history ending at *curr_date*.

    Returns a dict with:
      - "volumes"/"prices": history excluding the latest bar,
      - "current_volume"/"current_price": the latest bar,
      - "price_change_pct": last close vs. previous close, in percent.
    Empty/None/0.0 defaults are returned when data cannot be loaded.
    """
    from tradingagents.dataflows.config import get_config
    config = get_config()
    # Any vendor other than "local" fetches (and file-caches) via yfinance.
    online = config["data_vendors"]["technical_indicators"] != "local"
    result = {
        "volumes": [],
        "prices": [],
        "current_price": None,
        "current_volume": None,
        "price_change_pct": 0.0,
    }
    try:
        if not online:
            # NOTE(review): the local CSV filename hard-codes the
            # 2015-01-01..2025-03-25 range — confirm this matches the
            # offline dataset layout.
            data = pd.read_csv(
                os.path.join(
                    config.get("data_cache_dir", "data"),
                    f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv",
                )
            )
        else:
            import yfinance as yf
            # Download (or reuse a cached CSV of) the trailing 15 years.
            today_date = pd.Timestamp.today()
            end_date = today_date
            start_date = today_date - pd.DateOffset(years=15)
            start_date_str = start_date.strftime("%Y-%m-%d")
            end_date_str = end_date.strftime("%Y-%m-%d")
            os.makedirs(config["data_cache_dir"], exist_ok=True)
            data_file = os.path.join(
                config["data_cache_dir"],
                f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv",
            )
            if os.path.exists(data_file):
                data = pd.read_csv(data_file)
            else:
                # NOTE(review): network errors raised by yf.download are not
                # in the except tuple below and would propagate — confirm
                # callers tolerate that.
                data = yf.download(
                    ticker,
                    start=start_date_str,
                    end=end_date_str,
                    multi_level_index=False,
                    progress=False,
                    auto_adjust=True,
                )
                data = data.reset_index()
                data.to_csv(data_file, index=False)
        data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d")
        # Keep only bars at or before curr_date, capped at the last 30 rows.
        data = data[data["Date"] <= curr_date].tail(30)
        if len(data) > 0:
            volumes = data["Volume"].tolist()
            prices = data["Close"].tolist()
            # History excludes the latest bar so averages don't include it
            # (unless only one bar exists).
            result["volumes"] = volumes[:-1] if len(volumes) > 1 else volumes
            result["prices"] = prices[:-1] if len(prices) > 1 else prices
            result["current_volume"] = volumes[-1] if volumes else None
            result["current_price"] = prices[-1] if prices else None
            if len(prices) >= 2:
                result["price_change_pct"] = (
                    (prices[-1] - prices[-2]) / prices[-2] * 100
                    if prices[-2] != 0
                    else 0
                )
    except (FileNotFoundError, KeyError, ValueError) as e:
        logger.warning("Failed to get volume/price data for %s: %s", ticker, str(e))
    return result
def calculate_volume_metrics(ticker: str, curr_date: str) -> dict:
    """Compute volume-based metrics and a composite volume score for *ticker*.

    Returns a dict with volume_ratio (latest volume vs trailing 20-day
    average), volume_trend / volume_trend_slope, dollar_volume, and a
    volume_score in [0, 1]. Neutral defaults (None values, 0.5 score) are
    returned when data is missing or a step fails.
    """
    result = {
        "volume_ratio": None,
        "volume_trend": None,
        "volume_trend_slope": None,
        "dollar_volume": None,
        "volume_score": 0.5,  # neutral default
    }
    try:
        settings = get_settings()
        min_dollar_volume = settings.min_dollar_volume
        data = _get_volume_price_data(ticker, curr_date)
        volumes = data["volumes"]
        current_volume = data["current_volume"]
        current_price = data["current_price"]
        price_change = data["price_change_pct"]
        if not volumes or current_volume is None or current_price is None:
            return result
        # Trailing average excludes the latest bar ("volumes" omits it).
        # NOTE(review): min(len(...), 20) is redundant — the slice is
        # already at most 20 elements — but harmless.
        avg_volume_20d = (
            sum(volumes[-20:]) / min(len(volumes[-20:]), 20) if volumes else 0
        )
        volume_ratio = calculate_volume_ratio(current_volume, avg_volume_20d)
        result["volume_ratio"] = volume_ratio
        # Trend over the last 10 historical bars plus the current one.
        trend_volumes = (
            volumes[-10:] + [current_volume] if volumes else [current_volume]
        )
        trend, slope = calculate_volume_trend(trend_volumes)
        result["volume_trend"] = trend
        result["volume_trend_slope"] = slope
        dollar_volume = calculate_dollar_volume(current_price, current_volume)
        result["dollar_volume"] = dollar_volume
        result["volume_score"] = calculate_volume_score(
            volume_ratio=volume_ratio,
            trend=trend,
            dollar_volume=dollar_volume,
            price_change=price_change,
            min_dollar_volume=min_dollar_volume,
        )
    except (KeyError, ValueError, RuntimeError) as e:
        logger.warning("Failed to calculate volume metrics for %s: %s", ticker, str(e))
    return result

View File

@ -1,7 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
class DiscoveryStatus(Enum):
@ -50,7 +55,7 @@ class NewsArticle:
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "NewsArticle":
def from_dict(cls, data: dict[str, Any]) -> NewsArticle:
return cls(
title=data["title"],
source=data["source"],
@ -72,9 +77,11 @@ class TrendingStock:
event_type: EventCategory
news_summary: str
source_articles: list[NewsArticle]
quantitative_metrics: QuantitativeMetrics | None = None
conviction_score: float | None = None
def to_dict(self) -> dict[str, Any]:
return {
result = {
"ticker": self.ticker,
"company_name": self.company_name,
"score": self.score,
@ -85,9 +92,24 @@ class TrendingStock:
"news_summary": self.news_summary,
"source_articles": [article.to_dict() for article in self.source_articles],
}
if self.quantitative_metrics is not None:
result["quantitative_metrics"] = self.quantitative_metrics.to_dict()
if self.conviction_score is not None:
result["conviction_score"] = self.conviction_score
return result
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TrendingStock":
def from_dict(cls, data: dict[str, Any]) -> TrendingStock:
from tradingagents.agents.discovery.quantitative_models import (
QuantitativeMetrics,
)
quantitative_metrics = None
if data.get("quantitative_metrics"):
quantitative_metrics = QuantitativeMetrics.from_dict(
data["quantitative_metrics"]
)
return cls(
ticker=data["ticker"],
company_name=data["company_name"],
@ -100,6 +122,8 @@ class TrendingStock:
source_articles=[
NewsArticle.from_dict(article) for article in data["source_articles"]
],
quantitative_metrics=quantitative_metrics,
conviction_score=data.get("conviction_score"),
)
@ -125,7 +149,7 @@ class DiscoveryRequest:
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "DiscoveryRequest":
def from_dict(cls, data: dict[str, Any]) -> DiscoveryRequest:
return cls(
lookback_period=data["lookback_period"],
sector_filter=(
@ -165,7 +189,7 @@ class DiscoveryResult:
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "DiscoveryResult":
def from_dict(cls, data: dict[str, Any]) -> DiscoveryResult:
return cls(
request=DiscoveryRequest.from_dict(data["request"]),
trending_stocks=[

View File

@ -0,0 +1,47 @@
import logging
from collections import OrderedDict
from datetime import datetime
from typing import Any
import pandas as pd
logger = logging.getLogger(__name__)
MAX_CACHE_SIZE = 100
_price_data_cache: OrderedDict[str, tuple[pd.DataFrame, datetime]] = OrderedDict()
def get_cached_price_data(ticker: str) -> pd.DataFrame | None:
if ticker not in _price_data_cache:
return None
df, timestamp = _price_data_cache[ticker]
_price_data_cache.move_to_end(ticker)
return df.copy()
def set_cached_price_data(ticker: str, data: pd.DataFrame) -> None:
global _price_data_cache
if ticker in _price_data_cache:
_price_data_cache.move_to_end(ticker)
_price_data_cache[ticker] = (data.copy(), datetime.now())
return
while len(_price_data_cache) >= MAX_CACHE_SIZE:
_price_data_cache.popitem(last=False)
_price_data_cache[ticker] = (data.copy(), datetime.now())
logger.debug(
"Cached price data for %s (cache size: %d)", ticker, len(_price_data_cache)
)
def clear_run_cache() -> None:
global _price_data_cache
count = len(_price_data_cache)
_price_data_cache.clear()
logger.debug("Cleared run cache (%d entries removed)", count)

View File

@ -0,0 +1,91 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
class QuantitativeMetrics(BaseModel):
    """Validated container for a ticker's quantitative indicator data.

    The four component scores and the blended quantitative_score are
    constrained to [0, 1]; the raw indicator fields are optional and stay
    None when the corresponding calculation was unavailable. String fields
    are restricted to fixed vocabularies by the validators below.
    """
    # Component scores, each in [0, 1].
    momentum_score: float = Field(ge=0.0, le=1.0)
    volume_score: float = Field(ge=0.0, le=1.0)
    relative_strength_score: float = Field(ge=0.0, le=1.0)
    risk_reward_score: float = Field(ge=0.0, le=1.0)
    # Momentum indicators.
    rsi: float | None = None
    macd: float | None = None
    macd_signal: float | None = None
    macd_histogram: float | None = None
    price_vs_sma50: float | None = None
    price_vs_sma200: float | None = None
    ema10_direction: str | None = None
    # Volume indicators.
    volume_ratio: float | None = None
    volume_trend: str | None = None
    dollar_volume: float | None = None
    # Relative strength vs SPY and the sector ETF.
    rs_vs_spy_5d: float | None = None
    rs_vs_spy_20d: float | None = None
    rs_vs_spy_60d: float | None = None
    rs_vs_sector: float | None = None
    sector_etf: str | None = None
    # Support/resistance and risk/reward levels.
    support_level: float | None = None
    resistance_level: float | None = None
    atr: float | None = None
    suggested_stop: float | None = None
    reward_target: float | None = None
    risk_reward_ratio: float | None = None
    # Multi-timeframe signal summary.
    timeframe_alignment: str | None = None
    short_term_signal: str | None = None
    medium_term_signal: str | None = None
    long_term_signal: str | None = None
    signal_strength: float | None = None
    # Weighted blend of the four component scores, in [0, 1].
    quantitative_score: float = Field(ge=0.0, le=1.0)
    @field_validator("ema10_direction")
    @classmethod
    def validate_ema10_direction(cls, v: str | None) -> str | None:
        """Restrict ema10_direction to "up"/"down"/"flat" (or None)."""
        if v is None:
            return v
        valid_directions = {"up", "down", "flat"}
        if v not in valid_directions:
            raise ValueError(f"ema10_direction must be one of {valid_directions}")
        return v
    @field_validator("volume_trend")
    @classmethod
    def validate_volume_trend(cls, v: str | None) -> str | None:
        """Restrict volume_trend to "increasing"/"decreasing"/"flat" (or None)."""
        if v is None:
            return v
        valid_trends = {"increasing", "decreasing", "flat"}
        if v not in valid_trends:
            raise ValueError(f"volume_trend must be one of {valid_trends}")
        return v
    @field_validator("timeframe_alignment")
    @classmethod
    def validate_timeframe_alignment(cls, v: str | None) -> str | None:
        """Restrict timeframe_alignment to the four known alignment labels (or None)."""
        if v is None:
            return v
        valid_alignments = {"aligned_bullish", "aligned_bearish", "mixed", "neutral"}
        if v not in valid_alignments:
            raise ValueError(f"timeframe_alignment must be one of {valid_alignments}")
        return v
    @field_validator("short_term_signal", "medium_term_signal", "long_term_signal")
    @classmethod
    def validate_signal(cls, v: str | None) -> str | None:
        """Restrict each timeframe signal to "bullish"/"bearish"/"neutral" (or None)."""
        if v is None:
            return v
        valid_signals = {"bullish", "bearish", "neutral"}
        if v not in valid_signals:
            raise ValueError(f"signal must be one of {valid_signals}")
        return v
    def to_dict(self) -> dict[str, Any]:
        """Serialise to a plain dict (pydantic model_dump)."""
        return self.model_dump()
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "QuantitativeMetrics":
        """Reconstruct a validated instance from a plain dict."""
        return cls(**data)

View File

@ -0,0 +1,178 @@
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING
from tradingagents.agents.discovery.indicators import (
calculate_momentum_score,
calculate_relative_strength_metrics,
calculate_support_resistance_metrics,
calculate_volume_metrics,
)
from tradingagents.agents.discovery.indicators.timeframe import (
calculate_timeframe_signals,
)
from tradingagents.agents.discovery.quantitative_cache import clear_run_cache
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
from tradingagents.config import QuantitativeWeightsConfig, get_settings
if TYPE_CHECKING:
from tradingagents.agents.discovery.models import TrendingStock
logger = logging.getLogger(__name__)
def calculate_unified_score(
    momentum: float,
    volume: float,
    rs: float,
    rr: float,
    weights: QuantitativeWeightsConfig,
) -> float:
    """Weighted blend of the four component scores, clamped to [0, 1]."""
    weighted_parts = (
        momentum * weights.momentum_weight,
        volume * weights.volume_weight,
        rs * weights.relative_strength_weight,
        rr * weights.risk_reward_weight,
    )
    return min(1.0, max(0.0, sum(weighted_parts)))
def calculate_single_stock_metrics(
    ticker: str,
    curr_date: str,
) -> QuantitativeMetrics | None:
    """Compute the full QuantitativeMetrics record for one ticker.

    Gathers momentum, volume, relative-strength and support/resistance
    data as of *curr_date*, derives multi-timeframe signals from the
    momentum and RS dicts, and blends the four component scores into a
    unified quantitative_score using the configured weights. Returns
    None on any failure (this runs as a thread-pool worker, so errors
    are logged and degrade to "no quantitative data").
    """
    try:
        settings = get_settings()
        weights = settings.quantitative_weights
        momentum = calculate_momentum_score(ticker, curr_date)
        volume = calculate_volume_metrics(ticker, curr_date)
        rs = calculate_relative_strength_metrics(ticker, curr_date)
        sr = calculate_support_resistance_metrics(ticker, curr_date)
        timeframe = calculate_timeframe_signals(momentum, rs)
        # Component scores fall back to the neutral 0.5 when absent.
        momentum_score = momentum.get("momentum_score", 0.5)
        volume_score = volume.get("volume_score", 0.5)
        rs_score = rs.get("relative_strength_score", 0.5)
        rr_score = sr.get("risk_reward_score", 0.5)
        unified_score = calculate_unified_score(
            momentum=momentum_score,
            volume=volume_score,
            rs=rs_score,
            rr=rr_score,
            weights=weights,
        )
        return QuantitativeMetrics(
            momentum_score=momentum_score,
            volume_score=volume_score,
            relative_strength_score=rs_score,
            risk_reward_score=rr_score,
            rsi=momentum.get("rsi"),
            macd=momentum.get("macd"),
            macd_signal=momentum.get("macd_signal"),
            macd_histogram=momentum.get("macd_histogram"),
            price_vs_sma50=momentum.get("price_vs_sma50"),
            price_vs_sma200=momentum.get("price_vs_sma200"),
            ema10_direction=momentum.get("ema10_direction"),
            volume_ratio=volume.get("volume_ratio"),
            volume_trend=volume.get("volume_trend"),
            dollar_volume=volume.get("dollar_volume"),
            rs_vs_spy_5d=rs.get("rs_vs_spy_5d"),
            rs_vs_spy_20d=rs.get("rs_vs_spy_20d"),
            rs_vs_spy_60d=rs.get("rs_vs_spy_60d"),
            rs_vs_sector=rs.get("rs_vs_sector"),
            sector_etf=rs.get("sector_etf"),
            support_level=sr.get("support_level"),
            resistance_level=sr.get("resistance_level"),
            atr=sr.get("atr"),
            suggested_stop=sr.get("suggested_stop"),
            reward_target=sr.get("reward_target"),
            risk_reward_ratio=sr.get("risk_reward_ratio"),
            timeframe_alignment=timeframe.get("timeframe_alignment"),
            short_term_signal=timeframe.get("short_term_signal"),
            medium_term_signal=timeframe.get("medium_term_signal"),
            long_term_signal=timeframe.get("long_term_signal"),
            signal_strength=timeframe.get("signal_strength"),
            quantitative_score=unified_score,
        )
    except Exception as e:
        # Deliberately broad: per-ticker worker boundary — any failure is
        # logged and reported as None rather than aborting the whole run.
        logger.warning(
            "Failed to calculate quantitative metrics for %s: %s", ticker, str(e)
        )
        return None
def _normalize_score(score: float, max_score: float) -> float:
if max_score == 0:
return 0.0
return min(1.0, score / max_score)
def enhance_with_quantitative_scores(
    stocks: list["TrendingStock"],
    curr_date: str,
    max_stocks: int = 50,
) -> list["TrendingStock"]:
    """Attach quantitative metrics and conviction scores to trending stocks.

    The top *max_stocks* stocks by news score are scored concurrently; each
    receives quantitative_metrics and a conviction_score blending its
    normalized news score with the quantitative score using the configured
    weights. Stocks beyond the cutoff keep None for both fields. Returns
    the scored stocks sorted by conviction (descending) followed by the
    unscored remainder in news-score order. Mutates the stock objects.
    """
    if not stocks:
        return []
    settings = get_settings()
    weights = settings.quantitative_weights
    sorted_stocks = sorted(stocks, key=lambda s: s.score, reverse=True)
    stocks_to_process = sorted_stocks[:max_stocks]
    remaining_stocks = sorted_stocks[max_stocks:]
    # Start from an empty per-run price cache.
    clear_run_cache()
    results: dict[str, QuantitativeMetrics | None] = {}
    with ThreadPoolExecutor(max_workers=10) as executor:
        future_to_ticker = {
            executor.submit(
                calculate_single_stock_metrics, stock.ticker, curr_date
            ): stock.ticker
            for stock in stocks_to_process
        }
        for future in as_completed(future_to_ticker):
            ticker = future_to_ticker[future]
            try:
                results[ticker] = future.result()
            except Exception as e:
                # A failed worker simply leaves that ticker unscored.
                logger.warning("Quantitative scoring failed for %s: %s", ticker, str(e))
                results[ticker] = None
    # NOTE(review): results are keyed by ticker, so duplicate tickers in the
    # input would share one outcome — confirm inputs are de-duplicated.
    max_news_score = max((s.score for s in stocks_to_process), default=1.0) or 1.0
    for stock in stocks_to_process:
        metrics = results.get(stock.ticker)
        stock.quantitative_metrics = metrics
        news_normalized = _normalize_score(stock.score, max_news_score)
        if metrics is not None:
            stock.conviction_score = (
                weights.news_sentiment_weight * news_normalized
                + weights.quantitative_weight * metrics.quantitative_score
            )
        else:
            # Without metrics, only the news component contributes, so
            # unscored stocks naturally rank below scored peers.
            stock.conviction_score = news_normalized * weights.news_sentiment_weight
    for stock in remaining_stocks:
        stock.quantitative_metrics = None
        stock.conviction_score = None
    enhanced_stocks = sorted(
        stocks_to_process,
        key=lambda s: s.conviction_score if s.conviction_score is not None else 0.0,
        reverse=True,
    )
    clear_run_cache()
    return enhanced_stocks + remaining_stocks

View File

@ -1,7 +1,7 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic_settings import BaseSettings
@ -12,6 +12,39 @@ class DataVendorsConfig(BaseModel):
news_data: str = "alpha_vantage"
class QuantitativeWeightsConfig(BaseModel):
    """Weights blending news sentiment with quantitative sub-scores.

    The two top-level weights (news vs quantitative) must sum to 1.0, as
    must the four quantitative sub-weights; both sums are checked to
    within a 0.01 tolerance by the model validator.
    """
    news_sentiment_weight: float = Field(default=0.50, ge=0.0, le=1.0)
    quantitative_weight: float = Field(default=0.50, ge=0.0, le=1.0)
    momentum_weight: float = Field(default=0.30, ge=0.0, le=1.0)
    volume_weight: float = Field(default=0.25, ge=0.0, le=1.0)
    relative_strength_weight: float = Field(default=0.25, ge=0.0, le=1.0)
    risk_reward_weight: float = Field(default=0.20, ge=0.0, le=1.0)
    @model_validator(mode="after")
    def validate_weights_sum(self) -> "QuantitativeWeightsConfig":
        """Reject configurations whose weight groups do not each sum to 1.0."""
        top_level_sum = self.news_sentiment_weight + self.quantitative_weight
        if abs(top_level_sum - 1.0) > 0.01:
            raise ValueError(
                f"Top-level weights (news_sentiment_weight + quantitative_weight) "
                f"must sum to 1.0, got {top_level_sum}"
            )
        sub_weights_sum = (
            self.momentum_weight
            + self.volume_weight
            + self.relative_strength_weight
            + self.risk_reward_weight
        )
        if abs(sub_weights_sum - 1.0) > 0.01:
            raise ValueError(
                f"Sub-weights (momentum + volume + relative_strength + risk_reward) "
                f"must sum to 1.0, got {sub_weights_sum}"
            )
        return self
class TradingAgentsSettings(BaseSettings):
project_dir: str = Field(
default_factory=lambda: os.path.abspath(
@ -58,6 +91,14 @@ class TradingAgentsSettings(BaseSettings):
data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig)
tool_vendors: dict[str, Any] = Field(default_factory=dict)
quantitative_weights: QuantitativeWeightsConfig = Field(
default_factory=QuantitativeWeightsConfig
)
quantitative_max_stocks: int = Field(default=50, ge=10, le=100)
quantitative_cache_ttl_intraday: int = Field(default=1, ge=1)
quantitative_cache_ttl_relative_strength: int = Field(default=4, ge=1)
min_dollar_volume: float = Field(default=1_000_000.0, ge=0.0)
model_config = {
"env_prefix": "TRADINGAGENTS_",
"env_nested_delimiter": "__",
@ -104,6 +145,7 @@ class TradingAgentsSettings(BaseSettings):
def to_dict(self) -> dict[str, Any]:
result = self.model_dump()
result["data_vendors"] = self.data_vendors.model_dump()
result["quantitative_weights"] = self.quantitative_weights.model_dump()
return result
def get_api_key(self, vendor: str) -> str | None:

View File

@ -132,6 +132,11 @@ DEFAULT_TTL_HOURS = {
"get_insider_sentiment": 24,
"get_insider_transactions": 24,
"get_bulk_news": 1,
"quant_indicators": 1,
"volume_analysis": 1,
"relative_strength": 4,
"support_resistance": 1,
"risk_reward": 1,
}

View File

@ -22,6 +22,9 @@ from tradingagents.agents.discovery import (
calculate_trending_scores,
extract_entities,
)
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
from tradingagents.agents.utils.agent_utils import (
get_balance_sheet,
get_cashflow,
@ -303,12 +306,17 @@ class TradingAgentsGraph:
)
hard_timeout = self.config.get("discovery_hard_timeout", 120)
enable_quantitative = self.config.get("enable_quantitative_filtering", True)
quantitative_max_stocks = self.config.get("quantitative_max_stocks", 50)
discovery_result = {"stocks": [], "error": None}
def run_discovery():
try:
articles = get_bulk_news(request.lookback_period)
if not articles:
discovery_result["error"] = "No articles returned from news sources"
return
mentions = extract_entities(articles, self.config)
@ -326,6 +334,19 @@ class TradingAgentsGraph:
min_mentions=min_mentions,
)
if enable_quantitative and trending_stocks:
curr_date = date.today().strftime("%Y-%m-%d")
logger.info(
"Enhancing %d stocks with quantitative scores (max: %d)",
len(trending_stocks),
quantitative_max_stocks,
)
trending_stocks = enhance_with_quantitative_scores(
trending_stocks,
curr_date,
max_stocks=quantitative_max_stocks,
)
discovery_result["stocks"] = trending_stocks
except (
ValueError,

View File

@ -5149,6 +5149,7 @@ dependencies = [
{ name = "requests" },
{ name = "rich" },
{ name = "setuptools" },
{ name = "sqlalchemy" },
{ name = "stockstats" },
{ name = "tqdm" },
{ name = "tushare" },
@ -5196,6 +5197,7 @@ requires-dist = [
{ name = "rich", specifier = ">=14.0.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.2" },
{ name = "setuptools", specifier = ">=80.9.0" },
{ name = "sqlalchemy", specifier = ">=2.0.0" },
{ name = "stockstats", specifier = ">=0.6.5" },
{ name = "tqdm", specifier = ">=4.67.1" },
{ name = "tushare", specifier = ">=1.4.21" },