feat: add quantitative scoring with multi-timeframe analysis and CLI enhancements

Add quantitative scoring pipeline for discovery with technical indicator analysis:
- Momentum, volume, relative strength, and risk/reward scoring
- Support/resistance level detection
- Gap analysis for price momentum signals
- Configurable caching to reduce API calls

Implement multi-timeframe signal analysis:
- Short-term (5/20 day), medium-term (20/50 day), and long-term (50/200 day) signals
- Timeframe alignment detection (aligned_bullish, aligned_bearish, mixed, neutral)
- Signal strength calculation based on indicator agreement

Enhance CLI discovery display:
- Color-coded conviction scores (green/yellow/red thresholds)
- Signal column showing timeframe alignment status
- News mentions count column

Update tests to support new quantitative filtering configuration.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 17:04:18 -05:00
parent fb1a66f5a6
commit 1ba647ae4a
31 changed files with 4441 additions and 24 deletions

View File

@ -153,6 +153,35 @@ def select_event_filter() -> list[EventCategory] | None:
return choices return choices
def _get_conviction_display(stock: TrendingStock) -> tuple[str, str]:
if stock.conviction_score is None:
return "-", "dim"
score = stock.conviction_score
if score >= 0.7:
return f"{score:.2f}", "bold green"
elif score >= 0.5:
return f"{score:.2f}", "yellow"
else:
return f"{score:.2f}", "red"
def _get_signal_display(stock: TrendingStock) -> str:
if stock.quantitative_metrics is None:
return "-"
alignment = stock.quantitative_metrics.timeframe_alignment
if alignment == "aligned_bullish":
return "[bold green]+++[/bold green]"
elif alignment == "aligned_bearish":
return "[bold red]---[/bold red]"
elif alignment == "mixed":
strength = stock.quantitative_metrics.signal_strength or 0.5
if strength > 0.5:
return "[yellow]++[/yellow]"
else:
return "[yellow]--[/yellow]"
return "[dim]~[/dim]"
def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Table: def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Table:
table = Table( table = Table(
show_header=True, show_header=True,
@ -164,11 +193,12 @@ def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Tabl
) )
table.add_column("Rank", style="cyan", justify="center", width=6) table.add_column("Rank", style="cyan", justify="center", width=6)
table.add_column("Ticker", style="bold yellow", justify="center", width=10) table.add_column("Ticker", style="bold yellow", justify="center", width=8)
table.add_column("Company", style="white", justify="left", width=25) table.add_column("Company", style="white", justify="left", width=20)
table.add_column("Score", style="green", justify="right", width=10) table.add_column("Conv.", justify="right", width=6)
table.add_column("Mentions", style="blue", justify="center", width=10) table.add_column("Signal", justify="center", width=7)
table.add_column("Event Type", style="magenta", justify="center", width=18) table.add_column("News", style="blue", justify="right", width=6)
table.add_column("Event Type", style="magenta", justify="center", width=15)
for rank, stock in enumerate(trending_stocks, 1): for rank, stock in enumerate(trending_stocks, 1):
if rank <= 3: if rank <= 3:
@ -178,20 +208,54 @@ def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Tabl
rank_display = str(rank) rank_display = str(rank)
ticker_display = stock.ticker ticker_display = stock.ticker
conviction_text, conviction_style = _get_conviction_display(stock)
signal_display = _get_signal_display(stock)
table.add_row( table.add_row(
rank_display, rank_display,
ticker_display, ticker_display,
stock.company_name[:25] stock.company_name[:20]
if len(stock.company_name) > 25 if len(stock.company_name) > 20
else stock.company_name, else stock.company_name,
f"{stock.score:.2f}", f"[{conviction_style}]{conviction_text}[/{conviction_style}]",
str(stock.mention_count), signal_display,
stock.event_type.value.replace("_", " ").title(), f"{stock.score:.1f}",
stock.event_type.value.replace("_", " ").title()[:15],
) )
return table return table
def _format_timeframe_signals(stock: TrendingStock) -> str:
if stock.quantitative_metrics is None:
return "[dim]No quantitative data available[/dim]"
qm = stock.quantitative_metrics
short_color = (
"green"
if qm.short_term_signal == "bullish"
else "red"
if qm.short_term_signal == "bearish"
else "yellow"
)
med_color = (
"green"
if qm.medium_term_signal == "bullish"
else "red"
if qm.medium_term_signal == "bearish"
else "yellow"
)
long_color = (
"green"
if qm.long_term_signal == "bullish"
else "red"
if qm.long_term_signal == "bearish"
else "yellow"
)
return f"[{short_color}]Short: {(qm.short_term_signal or 'N/A').upper()}[/{short_color}] | [{med_color}]Med: {(qm.medium_term_signal or 'N/A').upper()}[/{med_color}] | [{long_color}]Long: {(qm.long_term_signal or 'N/A').upper()}[/{long_color}]"
def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel: def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel:
sentiment_label = ( sentiment_label = (
"positive" "positive"
@ -208,14 +272,64 @@ def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel:
else "yellow" else "yellow"
) )
conviction_text = (
f"{stock.conviction_score:.2f}" if stock.conviction_score is not None else "N/A"
)
conviction_color = (
"green"
if stock.conviction_score and stock.conviction_score >= 0.7
else "yellow"
if stock.conviction_score and stock.conviction_score >= 0.5
else "red"
)
content = f"""[bold]Rank #{rank}: {stock.ticker} - {stock.company_name}[/bold] content = f"""[bold]Rank #{rank}: {stock.ticker} - {stock.company_name}[/bold]
[cyan]Score:[/cyan] {stock.score:.2f} [cyan]Conviction Score:[/cyan] [{conviction_color}]{conviction_text}[/{conviction_color}]
[cyan]News Score:[/cyan] {stock.score:.2f}
[cyan]Sentiment:[/cyan] [{sentiment_color}]{stock.sentiment:.2f} ({sentiment_label})[/{sentiment_color}] [cyan]Sentiment:[/cyan] [{sentiment_color}]{stock.sentiment:.2f} ({sentiment_label})[/{sentiment_color}]
[cyan]Sector:[/cyan] {stock.sector.value.replace("_", " ").title()} [cyan]Sector:[/cyan] {stock.sector.value.replace("_", " ").title()}
[cyan]Event Type:[/cyan] {stock.event_type.value.replace("_", " ").title()} [cyan]Event Type:[/cyan] {stock.event_type.value.replace("_", " ").title()}
[cyan]Mentions:[/cyan] {stock.mention_count} [cyan]Mentions:[/cyan] {stock.mention_count}
[bold]Timeframe Signals:[/bold]
{_format_timeframe_signals(stock)}"""
if stock.quantitative_metrics is not None:
qm = stock.quantitative_metrics
alignment_color = (
"green"
if qm.timeframe_alignment == "aligned_bullish"
else "red"
if qm.timeframe_alignment == "aligned_bearish"
else "yellow"
)
content += f"""
[bold]Quantitative Metrics:[/bold]
[cyan]Timeframe Alignment:[/cyan] [{alignment_color}]{(qm.timeframe_alignment or 'N/A').replace('_', ' ').upper()}[/{alignment_color}]
[cyan]Momentum Score:[/cyan] {qm.momentum_score:.2f} [cyan]Volume Score:[/cyan] {qm.volume_score:.2f}
[cyan]Relative Strength:[/cyan] {qm.relative_strength_score:.2f} [cyan]Risk/Reward:[/cyan] {qm.risk_reward_score:.2f}"""
if qm.rsi is not None:
rsi_color = "green" if qm.rsi < 35 else "red" if qm.rsi > 65 else "yellow"
content += f"\n[cyan]RSI:[/cyan] [{rsi_color}]{qm.rsi:.1f}[/{rsi_color}]"
if qm.support_level is not None and qm.resistance_level is not None:
content += f" [cyan]Support:[/cyan] ${qm.support_level:.2f} [cyan]Resistance:[/cyan] ${qm.resistance_level:.2f}"
if qm.risk_reward_ratio is not None:
rr_color = (
"green"
if qm.risk_reward_ratio >= 2.0
else "yellow"
if qm.risk_reward_ratio >= 1.0
else "red"
)
content += f"\n[cyan]Risk/Reward Ratio:[/cyan] [{rr_color}]{qm.risk_reward_ratio:.2f}:1[/{rr_color}]"
content += f"""
[bold]News Summary:[/bold] [bold]News Summary:[/bold]
{stock.news_summary} {stock.news_summary}

View File

@ -67,7 +67,9 @@ class TestDiscoverTrendingReturnsDiscoveryResult:
"discovery_cache_ttl": 300, "discovery_cache_ttl": 300,
"discovery_max_results": 20, "discovery_max_results": 20,
"discovery_min_mentions": 2, "discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
} }
graph.db_enabled = False
result = graph.discover_trending() result = graph.discover_trending()
@ -118,7 +120,9 @@ class TestSectorFilterParameter:
"discovery_cache_ttl": 300, "discovery_cache_ttl": 300,
"discovery_max_results": 20, "discovery_max_results": 20,
"discovery_min_mentions": 2, "discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
} }
graph.db_enabled = False
request = DiscoveryRequest( request = DiscoveryRequest(
lookback_period="24h", lookback_period="24h",
@ -162,7 +166,9 @@ class TestEventFilterParameter:
"discovery_cache_ttl": 300, "discovery_cache_ttl": 300,
"discovery_max_results": 20, "discovery_max_results": 20,
"discovery_min_mentions": 2, "discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
} }
graph.db_enabled = False
request = DiscoveryRequest( request = DiscoveryRequest(
lookback_period="24h", lookback_period="24h",

View File

@ -124,8 +124,9 @@ class TestResultsTableDisplay:
"Rank", "Rank",
"Ticker", "Ticker",
"Company", "Company",
"Score", "Conv.",
"Mentions", "Signal",
"News",
"Event Type", "Event Type",
] ]
for expected in expected_columns: for expected in expected_columns:

View File

@ -89,7 +89,9 @@ class TestEndToEndDiscoveryFlow:
"discovery_cache_ttl": 300, "discovery_cache_ttl": 300,
"discovery_max_results": 20, "discovery_max_results": 20,
"discovery_min_mentions": 2, "discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
} }
graph.db_enabled = False
request = DiscoveryRequest(lookback_period="24h") request = DiscoveryRequest(lookback_period="24h")
result = graph.discover_trending(request) result = graph.discover_trending(request)
@ -259,7 +261,9 @@ class TestNoTrendingStocksFound:
"discovery_cache_ttl": 300, "discovery_cache_ttl": 300,
"discovery_max_results": 20, "discovery_max_results": 20,
"discovery_min_mentions": 2, "discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
} }
graph.db_enabled = False
result = graph.discover_trending() result = graph.discover_trending()
@ -274,7 +278,16 @@ class TestAllStocksFilteredOutBySectorFilter:
def test_all_stocks_filtered_out_by_sector_filter( def test_all_stocks_filtered_out_by_sector_filter(
self, mock_scores, mock_extract, mock_bulk_news self, mock_scores, mock_extract, mock_bulk_news
): ):
mock_bulk_news.return_value = [] mock_bulk_news.return_value = [
NewsArticle(
title="Test article",
source="Test",
url="https://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = [] mock_extract.return_value = []
mock_scores.return_value = [ mock_scores.return_value = [
TrendingStock( TrendingStock(
@ -311,7 +324,9 @@ class TestAllStocksFilteredOutBySectorFilter:
"discovery_cache_ttl": 300, "discovery_cache_ttl": 300,
"discovery_max_results": 20, "discovery_max_results": 20,
"discovery_min_mentions": 2, "discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
} }
graph.db_enabled = False
request = DiscoveryRequest( request = DiscoveryRequest(
lookback_period="24h", lookback_period="24h",
@ -330,7 +345,16 @@ class TestAllStocksFilteredOutByEventFilter:
def test_all_stocks_filtered_out_by_event_filter( def test_all_stocks_filtered_out_by_event_filter(
self, mock_scores, mock_extract, mock_bulk_news self, mock_scores, mock_extract, mock_bulk_news
): ):
mock_bulk_news.return_value = [] mock_bulk_news.return_value = [
NewsArticle(
title="Test article",
source="Test",
url="https://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = [] mock_extract.return_value = []
mock_scores.return_value = [ mock_scores.return_value = [
TrendingStock( TrendingStock(
@ -356,7 +380,9 @@ class TestAllStocksFilteredOutByEventFilter:
"discovery_cache_ttl": 300, "discovery_cache_ttl": 300,
"discovery_max_results": 20, "discovery_max_results": 20,
"discovery_min_mentions": 2, "discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
} }
graph.db_enabled = False
request = DiscoveryRequest( request = DiscoveryRequest(
lookback_period="24h", lookback_period="24h",
@ -375,7 +401,16 @@ class TestMultipleSectorsAndEventsFiltering:
def test_combined_sector_and_event_filtering( def test_combined_sector_and_event_filtering(
self, mock_scores, mock_extract, mock_bulk_news self, mock_scores, mock_extract, mock_bulk_news
): ):
mock_bulk_news.return_value = [] mock_bulk_news.return_value = [
NewsArticle(
title="Test article",
source="Test",
url="https://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = [] mock_extract.return_value = []
mock_scores.return_value = [ mock_scores.return_value = [
TrendingStock( TrendingStock(
@ -423,7 +458,9 @@ class TestMultipleSectorsAndEventsFiltering:
"discovery_cache_ttl": 300, "discovery_cache_ttl": 300,
"discovery_max_results": 20, "discovery_max_results": 20,
"discovery_min_mentions": 2, "discovery_min_mentions": 2,
"enable_quantitative_filtering": False,
} }
graph.db_enabled = False
request = DiscoveryRequest( request = DiscoveryRequest(
lookback_period="24h", lookback_period="24h",

View File

@ -0,0 +1,197 @@
from unittest.mock import MagicMock, patch
import pytest
class TestCalculateRsiScore:
def test_rsi_oversold_bullish_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_rsi_score,
)
score = calculate_rsi_score(25.0)
assert 0.8 <= score <= 1.0
score = calculate_rsi_score(30.0)
assert 0.7 <= score <= 0.9
def test_rsi_neutral_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_rsi_score,
)
score = calculate_rsi_score(50.0)
assert 0.5 <= score <= 0.7
score = calculate_rsi_score(40.0)
assert 0.5 <= score <= 0.7
def test_rsi_overbought_warning_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_rsi_score,
)
score = calculate_rsi_score(70.0)
assert 0.3 <= score <= 0.5
score = calculate_rsi_score(75.0)
assert 0.2 <= score <= 0.5
def test_rsi_extreme_overbought_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_rsi_score,
)
score = calculate_rsi_score(85.0)
assert 0.0 <= score <= 0.3
score = calculate_rsi_score(95.0)
assert 0.0 <= score <= 0.2
class TestCalculateMacdScore:
def test_macd_bullish_crossover_high_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_macd_score,
)
score = calculate_macd_score(macd=1.5, signal=1.0, histogram=0.5)
assert 0.7 <= score <= 1.0
def test_macd_bearish_crossover_low_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_macd_score,
)
score = calculate_macd_score(macd=-1.5, signal=-1.0, histogram=-0.5)
assert 0.0 <= score <= 0.4
def test_macd_neutral_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_macd_score,
)
score = calculate_macd_score(macd=0.1, signal=0.1, histogram=0.0)
assert 0.4 <= score <= 0.6
def test_macd_expanding_histogram_bullish(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_macd_score,
)
score = calculate_macd_score(macd=2.0, signal=1.5, histogram=0.8)
assert 0.6 <= score <= 1.0
class TestCalculateSmaScore:
def test_price_above_both_smas_high_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_sma_score,
)
score = calculate_sma_score(price=150.0, sma50=140.0, sma200=130.0)
assert 0.7 <= score <= 1.0
def test_price_below_both_smas_low_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_sma_score,
)
score = calculate_sma_score(price=120.0, sma50=140.0, sma200=150.0)
assert 0.0 <= score <= 0.4
def test_price_between_smas_moderate_score(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_sma_score,
)
score = calculate_sma_score(price=145.0, sma50=150.0, sma200=130.0)
assert 0.4 <= score <= 0.7
def test_golden_alignment_bonus(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_sma_score,
)
score = calculate_sma_score(price=160.0, sma50=150.0, sma200=140.0)
assert score >= 0.8
class TestCalculateEmaDirection:
def test_ema_upward_trend(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_ema_direction,
)
ema_values = [100.0, 102.0, 104.0, 106.0, 108.0]
direction = calculate_ema_direction(ema_values)
assert direction == "up"
def test_ema_downward_trend(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_ema_direction,
)
ema_values = [108.0, 106.0, 104.0, 102.0, 100.0]
direction = calculate_ema_direction(ema_values)
assert direction == "down"
def test_ema_flat_trend(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_ema_direction,
)
ema_values = [100.0, 100.1, 99.9, 100.0, 100.1]
direction = calculate_ema_direction(ema_values)
assert direction == "flat"
def test_ema_empty_list_returns_flat(self):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_ema_direction,
)
ema_values = []
direction = calculate_ema_direction(ema_values)
assert direction == "flat"
class TestCalculateMomentumScore:
@patch("tradingagents.agents.discovery.indicators.momentum._get_stock_stats_bulk")
def test_calculate_momentum_score_returns_dict(self, mock_get_stats):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_momentum_score,
)
mock_get_stats.side_effect = lambda symbol, indicator, date: {
"2024-01-15": "50.0" if indicator == "rsi" else "1.0",
"2024-01-14": "49.0" if indicator == "rsi" else "0.9",
"2024-01-13": "48.0" if indicator == "rsi" else "0.8",
"2024-01-12": "47.0" if indicator == "rsi" else "0.7",
"2024-01-11": "46.0" if indicator == "rsi" else "0.6",
}
result = calculate_momentum_score("AAPL", "2024-01-15")
assert isinstance(result, dict)
assert "rsi" in result
assert "macd" in result
assert "macd_signal" in result
assert "macd_histogram" in result
assert "price_vs_sma50" in result
assert "price_vs_sma200" in result
assert "ema10_direction" in result
assert "momentum_score" in result
assert 0.0 <= result["momentum_score"] <= 1.0
@patch("tradingagents.agents.discovery.indicators.momentum._get_stock_stats_bulk")
def test_calculate_momentum_score_handles_missing_data(self, mock_get_stats):
from tradingagents.agents.discovery.indicators.momentum import (
calculate_momentum_score,
)
mock_get_stats.return_value = {}
result = calculate_momentum_score("INVALID", "2024-01-15")
assert isinstance(result, dict)
assert result["momentum_score"] == 0.5

View File

@ -0,0 +1,344 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
class TestDiscoverTrendingIntegration:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
@patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores")
def test_discover_trending_calls_quantitative_enhancement(
self, mock_enhance, mock_scores, mock_extract, mock_bulk_news
):
from tradingagents.agents.discovery.models import (
EventCategory,
Sector,
TrendingStock,
)
from tradingagents.dataflows.models import NewsArticle
from tradingagents.graph.trading_graph import TradingAgentsGraph
mock_bulk_news.return_value = [
NewsArticle(
title="Test Article",
source="Test",
url="http://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = []
mock_stock = TrendingStock(
ticker="AAPL",
company_name="Apple Inc",
score=85.0,
mention_count=10,
sentiment=0.7,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Test",
source_articles=[],
)
mock_scores.return_value = [mock_stock]
mock_enhance.return_value = [mock_stock]
config = {
"llm_provider": "openai",
"quick_think_llm": "gpt-4o-mini",
"deep_think_llm": "gpt-4o",
"backend_url": "https://api.openai.com/v1",
"project_dir": "/tmp/test",
"database_enabled": False,
"enable_quantitative_filtering": True,
}
with (
patch("tradingagents.graph.trading_graph.ChatOpenAI"),
patch("tradingagents.graph.trading_graph.FinancialSituationMemory"),
patch("tradingagents.graph.trading_graph.GraphSetup"),
):
graph = TradingAgentsGraph(config=config)
result = graph.discover_trending()
mock_enhance.assert_called_once()
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
@patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores")
def test_discover_trending_skips_quantitative_when_disabled(
self, mock_enhance, mock_scores, mock_extract, mock_bulk_news
):
from tradingagents.agents.discovery.models import (
EventCategory,
Sector,
TrendingStock,
)
from tradingagents.dataflows.models import NewsArticle
from tradingagents.graph.trading_graph import TradingAgentsGraph
mock_bulk_news.return_value = [
NewsArticle(
title="Test Article",
source="Test",
url="http://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = []
mock_stock = TrendingStock(
ticker="AAPL",
company_name="Apple Inc",
score=85.0,
mention_count=10,
sentiment=0.7,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Test",
source_articles=[],
)
mock_scores.return_value = [mock_stock]
config = {
"llm_provider": "openai",
"quick_think_llm": "gpt-4o-mini",
"deep_think_llm": "gpt-4o",
"backend_url": "https://api.openai.com/v1",
"project_dir": "/tmp/test",
"database_enabled": False,
"enable_quantitative_filtering": False,
}
with (
patch("tradingagents.graph.trading_graph.ChatOpenAI"),
patch("tradingagents.graph.trading_graph.FinancialSituationMemory"),
patch("tradingagents.graph.trading_graph.GraphSetup"),
):
graph = TradingAgentsGraph(config=config)
result = graph.discover_trending()
mock_enhance.assert_not_called()
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
@patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores")
def test_discover_trending_uses_config_max_stocks(
self, mock_enhance, mock_scores, mock_extract, mock_bulk_news
):
from tradingagents.agents.discovery.models import (
EventCategory,
Sector,
TrendingStock,
)
from tradingagents.dataflows.models import NewsArticle
from tradingagents.graph.trading_graph import TradingAgentsGraph
mock_bulk_news.return_value = [
NewsArticle(
title="Test Article",
source="Test",
url="http://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
mock_extract.return_value = []
mock_stocks = [
TrendingStock(
ticker=f"TICK{i}",
company_name=f"Company {i}",
score=100.0 - i,
mention_count=10,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.OTHER,
news_summary="Test",
source_articles=[],
)
for i in range(30)
]
mock_scores.return_value = mock_stocks
mock_enhance.return_value = mock_stocks
config = {
"llm_provider": "openai",
"quick_think_llm": "gpt-4o-mini",
"deep_think_llm": "gpt-4o",
"backend_url": "https://api.openai.com/v1",
"project_dir": "/tmp/test",
"database_enabled": False,
"enable_quantitative_filtering": True,
"quantitative_max_stocks": 25,
}
with (
patch("tradingagents.graph.trading_graph.ChatOpenAI"),
patch("tradingagents.graph.trading_graph.FinancialSituationMemory"),
patch("tradingagents.graph.trading_graph.GraphSetup"),
):
graph = TradingAgentsGraph(config=config)
result = graph.discover_trending()
call_args = mock_enhance.call_args
assert call_args[1].get("max_stocks", 50) == 25
class TestScorerConvictionSupport:
def test_calculate_trending_scores_preserves_original_score(self):
from tradingagents.agents.discovery.entity_extractor import EntityMention
from tradingagents.agents.discovery.models import (
EventCategory,
NewsArticle,
)
from tradingagents.agents.discovery.scorer import calculate_trending_scores
mentions = [
EntityMention(
company_name="Apple",
confidence=0.9,
sentiment=0.7,
event_type=EventCategory.EARNINGS,
context_snippet="Apple reports strong earnings",
article_id="article_0",
),
EntityMention(
company_name="Apple",
confidence=0.85,
sentiment=0.6,
event_type=EventCategory.EARNINGS,
context_snippet="Apple stock rises",
article_id="article_1",
),
]
articles = [
NewsArticle(
title="Article 1",
source="Test",
url="http://test.com/1",
published_at=datetime.now(),
content_snippet="Test content 1",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="Article 2",
source="Test",
url="http://test.com/2",
published_at=datetime.now(),
content_snippet="Test content 2",
ticker_mentions=["AAPL"],
),
]
with patch(
"tradingagents.agents.discovery.scorer.resolve_ticker"
) as mock_resolve:
mock_resolve.return_value = "AAPL"
result = calculate_trending_scores(mentions, articles)
assert len(result) == 1
assert result[0].ticker == "AAPL"
assert result[0].score > 0
assert result[0].conviction_score is None
def test_trending_stock_supports_conviction_score(self):
from tradingagents.agents.discovery.models import (
EventCategory,
Sector,
TrendingStock,
)
from tradingagents.agents.discovery.quantitative_models import (
QuantitativeMetrics,
)
stock = TrendingStock(
ticker="AAPL",
company_name="Apple Inc",
score=85.0,
mention_count=10,
sentiment=0.7,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Test",
source_articles=[],
)
assert stock.conviction_score is None
stock.conviction_score = 0.85
assert stock.conviction_score == 0.85
metrics = QuantitativeMetrics(
momentum_score=0.7,
volume_score=0.6,
relative_strength_score=0.65,
risk_reward_score=0.7,
quantitative_score=0.66,
)
stock.quantitative_metrics = metrics
assert stock.quantitative_metrics.quantitative_score == 0.66
class TestBackwardCompatibility:
def test_trending_stock_without_quantitative_fields(self):
from tradingagents.agents.discovery.models import (
EventCategory,
Sector,
TrendingStock,
)
stock = TrendingStock(
ticker="AAPL",
company_name="Apple Inc",
score=85.0,
mention_count=10,
sentiment=0.7,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Test",
source_articles=[],
)
assert stock.quantitative_metrics is None
assert stock.conviction_score is None
stock_dict = stock.to_dict()
assert (
"quantitative_metrics" not in stock_dict
or stock_dict.get("quantitative_metrics") is None
)
assert (
"conviction_score" not in stock_dict
or stock_dict.get("conviction_score") is None
)
def test_trending_stock_from_dict_without_quantitative_fields(self):
from tradingagents.agents.discovery.models import TrendingStock
data = {
"ticker": "AAPL",
"company_name": "Apple Inc",
"score": 85.0,
"mention_count": 10,
"sentiment": 0.7,
"sector": "technology",
"event_type": "earnings",
"news_summary": "Test",
"source_articles": [],
}
stock = TrendingStock.from_dict(data)
assert stock.ticker == "AAPL"
assert stock.quantitative_metrics is None
assert stock.conviction_score is None

View File

@ -0,0 +1,101 @@
import time
from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
class TestQuantitativeCache:
def test_cache_set_and_get(self):
from tradingagents.agents.discovery.quantitative_cache import (
clear_run_cache,
get_cached_price_data,
set_cached_price_data,
)
clear_run_cache()
df = pd.DataFrame(
{"Close": [100.0, 101.0, 102.0], "Volume": [1000, 1100, 1200]}
)
set_cached_price_data("AAPL", df)
cached = get_cached_price_data("AAPL")
assert cached is not None
assert len(cached) == 3
assert cached["Close"].tolist() == [100.0, 101.0, 102.0]
def test_cache_miss_returns_none(self):
from tradingagents.agents.discovery.quantitative_cache import (
clear_run_cache,
get_cached_price_data,
)
clear_run_cache()
result = get_cached_price_data("NONEXISTENT")
assert result is None
def test_cache_clear(self):
from tradingagents.agents.discovery.quantitative_cache import (
clear_run_cache,
get_cached_price_data,
set_cached_price_data,
)
clear_run_cache()
df = pd.DataFrame({"Close": [100.0]})
set_cached_price_data("AAPL", df)
assert get_cached_price_data("AAPL") is not None
clear_run_cache()
assert get_cached_price_data("AAPL") is None
def test_cache_max_size_enforcement(self):
from tradingagents.agents.discovery.quantitative_cache import (
MAX_CACHE_SIZE,
clear_run_cache,
get_cached_price_data,
set_cached_price_data,
)
clear_run_cache()
for i in range(MAX_CACHE_SIZE + 10):
ticker = f"TICKER{i}"
df = pd.DataFrame({"Close": [float(i)]})
set_cached_price_data(ticker, df)
cached_count = 0
for i in range(MAX_CACHE_SIZE + 10):
ticker = f"TICKER{i}"
if get_cached_price_data(ticker) is not None:
cached_count += 1
assert cached_count <= MAX_CACHE_SIZE
class TestCacheTTLConstants:
def test_default_ttl_hours_contains_quant_entries(self):
from tradingagents.database.services.market_data import DEFAULT_TTL_HOURS
assert "quant_indicators" in DEFAULT_TTL_HOURS
assert "volume_analysis" in DEFAULT_TTL_HOURS
assert "relative_strength" in DEFAULT_TTL_HOURS
assert "support_resistance" in DEFAULT_TTL_HOURS
assert "risk_reward" in DEFAULT_TTL_HOURS
def test_quant_ttl_values(self):
from tradingagents.database.services.market_data import DEFAULT_TTL_HOURS
assert DEFAULT_TTL_HOURS["quant_indicators"] == 1
assert DEFAULT_TTL_HOURS["volume_analysis"] == 1
assert DEFAULT_TTL_HOURS["relative_strength"] == 4
assert DEFAULT_TTL_HOURS["support_resistance"] == 1
assert DEFAULT_TTL_HOURS["risk_reward"] == 1

View File

@ -0,0 +1,124 @@
import os
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from tradingagents.config import (
QuantitativeWeightsConfig,
TradingAgentsSettings,
get_settings,
reset_settings,
)
class TestQuantitativeWeightsConfigDefaults:
def test_default_weight_values_are_set_correctly(self):
config = QuantitativeWeightsConfig()
assert config.news_sentiment_weight == 0.50
assert config.quantitative_weight == 0.50
assert config.momentum_weight == 0.30
assert config.volume_weight == 0.25
assert config.relative_strength_weight == 0.25
assert config.risk_reward_weight == 0.20
class TestQuantitativeWeightsConfigValidation:
def test_top_level_weights_sum_to_one(self):
config = QuantitativeWeightsConfig(
news_sentiment_weight=0.60,
quantitative_weight=0.40,
)
assert (
config.news_sentiment_weight + config.quantitative_weight
== pytest.approx(1.0)
)
def test_sub_weights_sum_to_one(self):
config = QuantitativeWeightsConfig()
sub_weights_sum = (
config.momentum_weight
+ config.volume_weight
+ config.relative_strength_weight
+ config.risk_reward_weight
)
assert sub_weights_sum == pytest.approx(1.0)
def test_top_level_weights_validation_rejects_invalid_sum(self):
with pytest.raises(ValidationError) as exc_info:
QuantitativeWeightsConfig(
news_sentiment_weight=0.60,
quantitative_weight=0.60,
)
assert "sum to 1.0" in str(exc_info.value).lower()
def test_sub_weights_validation_rejects_invalid_sum(self):
with pytest.raises(ValidationError) as exc_info:
QuantitativeWeightsConfig(
momentum_weight=0.50,
volume_weight=0.50,
relative_strength_weight=0.50,
risk_reward_weight=0.50,
)
assert "sum to 1.0" in str(exc_info.value).lower()
class TestQuantitativeWeightsConfigEnvOverride:
def setup_method(self):
reset_settings()
def teardown_method(self):
reset_settings()
def test_environment_variable_override_functionality(self):
env_vars = {
"TRADINGAGENTS_QUANTITATIVE_WEIGHTS__NEWS_SENTIMENT_WEIGHT": "0.70",
"TRADINGAGENTS_QUANTITATIVE_WEIGHTS__QUANTITATIVE_WEIGHT": "0.30",
}
with patch.dict(os.environ, env_vars, clear=False):
reset_settings()
settings = get_settings()
assert settings.quantitative_weights.news_sentiment_weight == pytest.approx(
0.70
)
assert settings.quantitative_weights.quantitative_weight == pytest.approx(
0.30
)
class TestQuantitativeSettingsIntegration:
def setup_method(self):
reset_settings()
def teardown_method(self):
reset_settings()
def test_quantitative_settings_in_trading_agents_settings(self):
settings = TradingAgentsSettings()
assert hasattr(settings, "quantitative_weights")
assert isinstance(settings.quantitative_weights, QuantitativeWeightsConfig)
assert hasattr(settings, "quantitative_max_stocks")
assert hasattr(settings, "quantitative_cache_ttl_intraday")
assert hasattr(settings, "quantitative_cache_ttl_relative_strength")
assert hasattr(settings, "min_dollar_volume")
def test_quantitative_settings_default_values(self):
settings = TradingAgentsSettings()
assert settings.quantitative_max_stocks == 50
assert settings.quantitative_cache_ttl_intraday == 1
assert settings.quantitative_cache_ttl_relative_strength == 4
assert settings.min_dollar_volume == 1_000_000.0
def test_quantitative_max_stocks_bounds(self):
settings = TradingAgentsSettings(quantitative_max_stocks=75)
assert settings.quantitative_max_stocks == 75
with pytest.raises(ValidationError):
TradingAgentsSettings(quantitative_max_stocks=5)
with pytest.raises(ValidationError):
TradingAgentsSettings(quantitative_max_stocks=150)

View File

@ -0,0 +1,465 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
from tradingagents.agents.discovery.models import (
EventCategory,
NewsArticle,
Sector,
TrendingStock,
)
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
from tradingagents.config import QuantitativeWeightsConfig
class TestFullPipelineIntegration:
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics"
)
def test_full_pipeline_news_to_quantitative_enhancement(
self, mock_sr, mock_rs, mock_vol, mock_mom
):
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
mock_mom.return_value = {
"rsi": 45.0,
"macd": 0.5,
"macd_signal": 0.4,
"macd_histogram": 0.1,
"price_vs_sma50": 5.0,
"price_vs_sma200": 10.0,
"ema10_direction": "up",
"momentum_score": 0.7,
}
mock_vol.return_value = {
"volume_ratio": 1.8,
"volume_trend": "increasing",
"dollar_volume": 25000000.0,
"volume_score": 0.75,
}
mock_rs.return_value = {
"rs_vs_spy_5d": 3.0,
"rs_vs_spy_20d": 5.0,
"rs_vs_spy_60d": 8.0,
"rs_vs_sector": 2.5,
"sector_etf": "XLK",
"relative_strength_score": 0.8,
}
mock_sr.return_value = {
"support_level": 145.0,
"resistance_level": 175.0,
"atr": 3.0,
"suggested_stop": 140.5,
"reward_target": 175.0,
"risk_reward_ratio": 2.5,
"risk_reward_score": 0.85,
}
articles = [
NewsArticle(
title="Apple Reports Strong Quarter",
source="Reuters",
url="http://example.com/1",
published_at=datetime.now(),
content_snippet="Apple beats expectations",
ticker_mentions=["AAPL"],
)
]
stocks = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc",
score=85.0,
mention_count=15,
sentiment=0.75,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple reports strong quarterly results",
source_articles=articles,
),
TrendingStock(
ticker="MSFT",
company_name="Microsoft Corp",
score=72.0,
mention_count=10,
sentiment=0.6,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.PRODUCT_LAUNCH,
news_summary="Microsoft launches new product",
source_articles=[],
),
]
result = enhance_with_quantitative_scores(stocks, "2024-01-15")
assert len(result) == 2
for stock in result:
if stock.quantitative_metrics is not None:
assert stock.conviction_score is not None
assert 0.0 <= stock.conviction_score <= 1.0
assert stock.quantitative_metrics.rsi == 45.0
assert stock.quantitative_metrics.sector_etf == "XLK"
class TestDelistedStockHandling:
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
)
def test_stock_with_no_trading_data_returns_none(self, mock_mom):
from tradingagents.agents.discovery.quantitative_scorer import (
calculate_single_stock_metrics,
)
mock_mom.side_effect = Exception("No price data available for delisted stock")
result = calculate_single_stock_metrics("DELIST", "2024-01-15")
assert result is None
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics"
)
def test_enhance_continues_after_delisted_stock_failure(
self, mock_sr, mock_rs, mock_vol, mock_mom
):
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
def mom_side_effect(ticker, date):
if ticker == "DELIST":
raise Exception("No data for delisted stock")
return {"momentum_score": 0.6, "rsi": 50.0}
mock_mom.side_effect = mom_side_effect
mock_vol.return_value = {"volume_score": 0.5}
mock_rs.return_value = {"relative_strength_score": 0.5, "sector_etf": "SPY"}
mock_sr.return_value = {"risk_reward_score": 0.5}
stocks = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc",
score=90.0,
mention_count=10,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.OTHER,
news_summary="Test",
source_articles=[],
),
TrendingStock(
ticker="DELIST",
company_name="Delisted Corp",
score=80.0,
mention_count=8,
sentiment=0.4,
sector=Sector.OTHER,
event_type=EventCategory.OTHER,
news_summary="Test",
source_articles=[],
),
]
result = enhance_with_quantitative_scores(stocks, "2024-01-15")
assert len(result) == 2
aapl = next((s for s in result if s.ticker == "AAPL"), None)
delist = next((s for s in result if s.ticker == "DELIST"), None)
assert aapl is not None
assert aapl.quantitative_metrics is not None
assert delist is not None
assert delist.quantitative_metrics is None
class TestWeightConfigurationEdgeCases:
def test_weights_near_boundary_accepted(self):
config = QuantitativeWeightsConfig(
news_sentiment_weight=0.501,
quantitative_weight=0.499,
)
total = config.news_sentiment_weight + config.quantitative_weight
assert abs(total - 1.0) < 0.01
def test_sub_weights_custom_values_accepted(self):
config = QuantitativeWeightsConfig(
momentum_weight=0.40,
volume_weight=0.20,
relative_strength_weight=0.20,
risk_reward_weight=0.20,
)
sub_total = (
config.momentum_weight
+ config.volume_weight
+ config.relative_strength_weight
+ config.risk_reward_weight
)
assert abs(sub_total - 1.0) < 0.01
class TestCacheConcurrentAccess:
def test_cache_thread_safety_under_concurrent_writes(self):
from tradingagents.agents.discovery.quantitative_cache import (
clear_run_cache,
get_cached_price_data,
set_cached_price_data,
)
clear_run_cache()
errors = []
results = {}
def write_to_cache(ticker_num):
try:
ticker = f"TICK{ticker_num}"
df = pd.DataFrame({"Close": [float(ticker_num)]})
set_cached_price_data(ticker, df)
cached = get_cached_price_data(ticker)
if cached is not None:
results[ticker] = cached["Close"].iloc[0]
except Exception as e:
errors.append(str(e))
with ThreadPoolExecutor(max_workers=20) as executor:
futures = [executor.submit(write_to_cache, i) for i in range(50)]
for f in futures:
f.result()
assert len(errors) == 0
clear_run_cache()
class TestConvictionScoreRanking:
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
)
def test_conviction_score_ranking_accuracy(self, mock_calc):
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
def side_effect(ticker, date):
scores = {
"HIGH_BOTH": (0.9, 0.9),
"HIGH_NEWS_LOW_QUANT": (0.3, 0.9),
"LOW_NEWS_HIGH_QUANT": (0.9, 0.3),
"LOW_BOTH": (0.3, 0.3),
}
quant, _ = scores.get(ticker, (0.5, 0.5))
return QuantitativeMetrics(
momentum_score=quant,
volume_score=quant,
relative_strength_score=quant,
risk_reward_score=quant,
quantitative_score=quant,
)
mock_calc.side_effect = side_effect
stocks = [
TrendingStock(
ticker="LOW_BOTH",
company_name="Low Both",
score=30.0,
mention_count=5,
sentiment=0.3,
sector=Sector.OTHER,
event_type=EventCategory.OTHER,
news_summary="Test",
source_articles=[],
),
TrendingStock(
ticker="HIGH_BOTH",
company_name="High Both",
score=90.0,
mention_count=15,
sentiment=0.9,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Test",
source_articles=[],
),
TrendingStock(
ticker="HIGH_NEWS_LOW_QUANT",
company_name="High News Low Quant",
score=90.0,
mention_count=15,
sentiment=0.9,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.OTHER,
news_summary="Test",
source_articles=[],
),
TrendingStock(
ticker="LOW_NEWS_HIGH_QUANT",
company_name="Low News High Quant",
score=30.0,
mention_count=5,
sentiment=0.3,
sector=Sector.OTHER,
event_type=EventCategory.OTHER,
news_summary="Test",
source_articles=[],
),
]
result = enhance_with_quantitative_scores(stocks, "2024-01-15")
assert result[0].ticker == "HIGH_BOTH"
high_both = next(s for s in result if s.ticker == "HIGH_BOTH")
low_both = next(s for s in result if s.ticker == "LOW_BOTH")
assert high_both.conviction_score > low_both.conviction_score
class TestTrendingStockSerializationWithQuantitativeMetrics:
def test_trending_stock_with_quantitative_metrics_roundtrip(self):
articles = [
NewsArticle(
title="Test Article",
source="Test Source",
url="http://test.com",
published_at=datetime(2024, 1, 15, 10, 30, 0),
content_snippet="Test content",
ticker_mentions=["AAPL"],
)
]
metrics = QuantitativeMetrics(
momentum_score=0.75,
volume_score=0.65,
relative_strength_score=0.80,
risk_reward_score=0.70,
rsi=42.5,
macd=0.35,
macd_signal=0.28,
macd_histogram=0.07,
price_vs_sma50=4.5,
price_vs_sma200=9.2,
ema10_direction="up",
volume_ratio=1.65,
volume_trend="increasing",
dollar_volume=18500000.0,
rs_vs_spy_5d=2.8,
rs_vs_spy_20d=4.2,
rs_vs_spy_60d=7.5,
rs_vs_sector=2.1,
sector_etf="XLK",
support_level=148.50,
resistance_level=165.75,
atr=2.85,
suggested_stop=144.22,
reward_target=165.75,
risk_reward_ratio=2.65,
quantitative_score=0.725,
)
original = TrendingStock(
ticker="AAPL",
company_name="Apple Inc",
score=87.5,
mention_count=12,
sentiment=0.72,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple reports strong quarterly earnings",
source_articles=articles,
quantitative_metrics=metrics,
conviction_score=0.815,
)
data = original.to_dict()
restored = TrendingStock.from_dict(data)
assert restored.ticker == original.ticker
assert restored.score == original.score
assert restored.conviction_score == original.conviction_score
assert restored.quantitative_metrics is not None
assert restored.quantitative_metrics.momentum_score == 0.75
assert restored.quantitative_metrics.rsi == 42.5
assert restored.quantitative_metrics.sector_etf == "XLK"
assert restored.quantitative_metrics.quantitative_score == 0.725
class TestModuleIntegration:
def test_unified_score_combines_all_indicators_correctly(self):
from tradingagents.agents.discovery.quantitative_scorer import (
calculate_unified_score,
)
weights = QuantitativeWeightsConfig()
score = calculate_unified_score(
momentum=0.8,
volume=0.6,
rs=0.7,
rr=0.9,
weights=weights,
)
expected = (
0.8 * weights.momentum_weight
+ 0.6 * weights.volume_weight
+ 0.7 * weights.relative_strength_weight
+ 0.9 * weights.risk_reward_weight
)
assert abs(score - expected) < 0.001
def test_unified_score_clamped_to_valid_range(self):
from tradingagents.agents.discovery.quantitative_scorer import (
calculate_unified_score,
)
weights = QuantitativeWeightsConfig()
score = calculate_unified_score(
momentum=1.5,
volume=1.5,
rs=1.5,
rr=1.5,
weights=weights,
)
assert score == 1.0
score_low = calculate_unified_score(
momentum=-0.5,
volume=-0.5,
rs=-0.5,
rr=-0.5,
weights=weights,
)
assert score_low == 0.0

View File

@ -0,0 +1,267 @@
import pytest
from pydantic import ValidationError
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
class TestQuantitativeMetricsInstantiation:
def test_model_instantiation_with_valid_data(self):
metrics = QuantitativeMetrics(
momentum_score=0.75,
volume_score=0.60,
relative_strength_score=0.80,
risk_reward_score=0.70,
rsi=45.5,
macd=0.25,
macd_signal=0.20,
macd_histogram=0.05,
price_vs_sma50=5.2,
price_vs_sma200=12.3,
ema10_direction="up",
volume_ratio=1.8,
volume_trend="increasing",
dollar_volume=15_000_000.0,
rs_vs_spy_5d=2.5,
rs_vs_spy_20d=5.0,
rs_vs_spy_60d=8.2,
rs_vs_sector=3.1,
sector_etf="XLK",
support_level=145.50,
resistance_level=162.30,
atr=3.25,
suggested_stop=140.62,
reward_target=162.30,
risk_reward_ratio=2.8,
quantitative_score=0.72,
)
assert metrics.momentum_score == 0.75
assert metrics.volume_score == 0.60
assert metrics.relative_strength_score == 0.80
assert metrics.risk_reward_score == 0.70
assert metrics.rsi == 45.5
assert metrics.macd == 0.25
assert metrics.ema10_direction == "up"
assert metrics.sector_etf == "XLK"
assert metrics.quantitative_score == 0.72
class TestQuantitativeMetricsScoreValidation:
def test_score_validation_accepts_valid_range(self):
metrics = QuantitativeMetrics(
momentum_score=0.0,
volume_score=0.5,
relative_strength_score=1.0,
risk_reward_score=0.99,
rsi=50.0,
macd=0.0,
macd_signal=0.0,
macd_histogram=0.0,
price_vs_sma50=0.0,
price_vs_sma200=0.0,
ema10_direction="flat",
volume_ratio=1.0,
volume_trend="flat",
dollar_volume=1_000_000.0,
rs_vs_spy_5d=0.0,
rs_vs_spy_20d=0.0,
rs_vs_spy_60d=0.0,
rs_vs_sector=0.0,
sector_etf="SPY",
support_level=100.0,
resistance_level=110.0,
atr=2.0,
suggested_stop=97.0,
reward_target=110.0,
risk_reward_ratio=3.33,
quantitative_score=0.5,
)
assert metrics.momentum_score == 0.0
assert metrics.relative_strength_score == 1.0
assert metrics.quantitative_score == 0.5
def test_score_validation_rejects_negative_score(self):
with pytest.raises(ValidationError) as exc_info:
QuantitativeMetrics(
momentum_score=-0.1,
volume_score=0.5,
relative_strength_score=0.5,
risk_reward_score=0.5,
rsi=50.0,
macd=0.0,
macd_signal=0.0,
macd_histogram=0.0,
price_vs_sma50=0.0,
price_vs_sma200=0.0,
ema10_direction="flat",
volume_ratio=1.0,
volume_trend="flat",
dollar_volume=1_000_000.0,
rs_vs_spy_5d=0.0,
rs_vs_spy_20d=0.0,
rs_vs_spy_60d=0.0,
rs_vs_sector=0.0,
sector_etf="SPY",
support_level=100.0,
resistance_level=110.0,
atr=2.0,
suggested_stop=97.0,
reward_target=110.0,
risk_reward_ratio=3.33,
quantitative_score=0.5,
)
assert "momentum_score" in str(exc_info.value)
def test_score_validation_rejects_above_one(self):
with pytest.raises(ValidationError) as exc_info:
QuantitativeMetrics(
momentum_score=0.5,
volume_score=1.5,
relative_strength_score=0.5,
risk_reward_score=0.5,
rsi=50.0,
macd=0.0,
macd_signal=0.0,
macd_histogram=0.0,
price_vs_sma50=0.0,
price_vs_sma200=0.0,
ema10_direction="flat",
volume_ratio=1.0,
volume_trend="flat",
dollar_volume=1_000_000.0,
rs_vs_spy_5d=0.0,
rs_vs_spy_20d=0.0,
rs_vs_spy_60d=0.0,
rs_vs_sector=0.0,
sector_etf="SPY",
support_level=100.0,
resistance_level=110.0,
atr=2.0,
suggested_stop=97.0,
reward_target=110.0,
risk_reward_ratio=3.33,
quantitative_score=0.5,
)
assert "volume_score" in str(exc_info.value)
class TestQuantitativeMetricsSerialization:
def test_to_dict_and_from_dict_roundtrip(self):
original = QuantitativeMetrics(
momentum_score=0.75,
volume_score=0.60,
relative_strength_score=0.80,
risk_reward_score=0.70,
rsi=45.5,
macd=0.25,
macd_signal=0.20,
macd_histogram=0.05,
price_vs_sma50=5.2,
price_vs_sma200=12.3,
ema10_direction="up",
volume_ratio=1.8,
volume_trend="increasing",
dollar_volume=15_000_000.0,
rs_vs_spy_5d=2.5,
rs_vs_spy_20d=5.0,
rs_vs_spy_60d=8.2,
rs_vs_sector=3.1,
sector_etf="XLK",
support_level=145.50,
resistance_level=162.30,
atr=3.25,
suggested_stop=140.62,
reward_target=162.30,
risk_reward_ratio=2.8,
quantitative_score=0.72,
)
data = original.to_dict()
restored = QuantitativeMetrics.from_dict(data)
assert restored.momentum_score == original.momentum_score
assert restored.volume_score == original.volume_score
assert restored.relative_strength_score == original.relative_strength_score
assert restored.risk_reward_score == original.risk_reward_score
assert restored.rsi == original.rsi
assert restored.macd == original.macd
assert restored.ema10_direction == original.ema10_direction
assert restored.sector_etf == original.sector_etf
assert restored.quantitative_score == original.quantitative_score
class TestQuantitativeMetricsOptionalFields:
def test_optional_field_handling_with_none_defaults(self):
metrics = QuantitativeMetrics(
momentum_score=0.5,
volume_score=0.5,
relative_strength_score=0.5,
risk_reward_score=0.5,
rsi=None,
macd=None,
macd_signal=None,
macd_histogram=None,
price_vs_sma50=None,
price_vs_sma200=None,
ema10_direction=None,
volume_ratio=None,
volume_trend=None,
dollar_volume=None,
rs_vs_spy_5d=None,
rs_vs_spy_20d=None,
rs_vs_spy_60d=None,
rs_vs_sector=None,
sector_etf=None,
support_level=None,
resistance_level=None,
atr=None,
suggested_stop=None,
reward_target=None,
risk_reward_ratio=None,
quantitative_score=0.5,
)
assert metrics.rsi is None
assert metrics.macd is None
assert metrics.ema10_direction is None
assert metrics.sector_etf is None
assert metrics.momentum_score == 0.5
assert metrics.quantitative_score == 0.5
def test_serialization_with_none_values(self):
original = QuantitativeMetrics(
momentum_score=0.5,
volume_score=0.5,
relative_strength_score=0.5,
risk_reward_score=0.5,
rsi=None,
macd=None,
macd_signal=None,
macd_histogram=None,
price_vs_sma50=None,
price_vs_sma200=None,
ema10_direction=None,
volume_ratio=None,
volume_trend=None,
dollar_volume=None,
rs_vs_spy_5d=None,
rs_vs_spy_20d=None,
rs_vs_spy_60d=None,
rs_vs_sector=None,
sector_etf=None,
support_level=None,
resistance_level=None,
atr=None,
suggested_stop=None,
reward_target=None,
risk_reward_ratio=None,
quantitative_score=0.5,
)
data = original.to_dict()
restored = QuantitativeMetrics.from_dict(data)
assert restored.rsi is None
assert restored.ema10_direction is None
assert restored.momentum_score == 0.5

View File

@ -0,0 +1,312 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
class TestCalculateSingleStockMetrics:
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics"
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics"
)
def test_calculate_single_stock_metrics_success(
self, mock_sr, mock_rs, mock_vol, mock_mom
):
from tradingagents.agents.discovery.quantitative_scorer import (
calculate_single_stock_metrics,
)
mock_mom.return_value = {
"rsi": 45.0,
"macd": 0.5,
"macd_signal": 0.4,
"macd_histogram": 0.1,
"price_vs_sma50": 2.5,
"price_vs_sma200": 5.0,
"ema10_direction": "up",
"momentum_score": 0.65,
}
mock_vol.return_value = {
"volume_ratio": 1.5,
"volume_trend": "increasing",
"dollar_volume": 5000000.0,
"volume_score": 0.7,
}
mock_rs.return_value = {
"rs_vs_spy_5d": 2.0,
"rs_vs_spy_20d": 3.5,
"rs_vs_spy_60d": 5.0,
"rs_vs_sector": 2.5,
"sector_etf": "XLK",
"relative_strength_score": 0.6,
}
mock_sr.return_value = {
"support_level": 150.0,
"resistance_level": 180.0,
"atr": 3.5,
"suggested_stop": 145.0,
"reward_target": 180.0,
"risk_reward_ratio": 2.5,
"risk_reward_score": 0.75,
}
result = calculate_single_stock_metrics("AAPL", "2024-01-15")
assert result is not None
assert result.momentum_score == 0.65
assert result.volume_score == 0.7
assert result.relative_strength_score == 0.6
assert result.risk_reward_score == 0.75
assert 0.0 <= result.quantitative_score <= 1.0
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score"
)
def test_calculate_single_stock_metrics_failure(self, mock_mom):
from tradingagents.agents.discovery.quantitative_scorer import (
calculate_single_stock_metrics,
)
mock_mom.side_effect = Exception("API failure")
result = calculate_single_stock_metrics("INVALID", "2024-01-15")
assert result is None
class TestCalculateUnifiedScore:
def test_unified_score_basic(self):
from tradingagents.agents.discovery.quantitative_scorer import (
calculate_unified_score,
)
from tradingagents.config import QuantitativeWeightsConfig
weights = QuantitativeWeightsConfig()
score = calculate_unified_score(
momentum=0.8,
volume=0.6,
rs=0.7,
rr=0.5,
weights=weights,
)
expected = 0.8 * 0.30 + 0.6 * 0.25 + 0.7 * 0.25 + 0.5 * 0.20
assert abs(score - expected) < 0.001
def test_unified_score_all_ones(self):
from tradingagents.agents.discovery.quantitative_scorer import (
calculate_unified_score,
)
from tradingagents.config import QuantitativeWeightsConfig
weights = QuantitativeWeightsConfig()
score = calculate_unified_score(
momentum=1.0,
volume=1.0,
rs=1.0,
rr=1.0,
weights=weights,
)
assert abs(score - 1.0) < 0.001
def test_unified_score_all_zeros(self):
from tradingagents.agents.discovery.quantitative_scorer import (
calculate_unified_score,
)
from tradingagents.config import QuantitativeWeightsConfig
weights = QuantitativeWeightsConfig()
score = calculate_unified_score(
momentum=0.0,
volume=0.0,
rs=0.0,
rr=0.0,
weights=weights,
)
assert abs(score - 0.0) < 0.001
class TestEnhanceWithQuantitativeScores:
def _create_mock_trending_stock(self, ticker: str, score: float):
from tradingagents.agents.discovery.models import (
EventCategory,
NewsArticle,
Sector,
TrendingStock,
)
return TrendingStock(
ticker=ticker,
company_name=f"{ticker} Inc",
score=score,
mention_count=10,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Test summary",
source_articles=[],
)
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
)
def test_enhance_with_quantitative_scores_basic(self, mock_calc):
from tradingagents.agents.discovery.quantitative_models import (
QuantitativeMetrics,
)
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
mock_calc.return_value = QuantitativeMetrics(
momentum_score=0.7,
volume_score=0.6,
relative_strength_score=0.65,
risk_reward_score=0.7,
quantitative_score=0.66,
)
stocks = [
self._create_mock_trending_stock("AAPL", 90.0),
self._create_mock_trending_stock("MSFT", 80.0),
]
result = enhance_with_quantitative_scores(stocks, "2024-01-15")
assert len(result) == 2
for stock in result:
assert stock.quantitative_metrics is not None
assert stock.conviction_score is not None
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
)
def test_enhance_caps_at_max_stocks(self, mock_calc):
from tradingagents.agents.discovery.quantitative_models import (
QuantitativeMetrics,
)
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
mock_calc.return_value = QuantitativeMetrics(
momentum_score=0.5,
volume_score=0.5,
relative_strength_score=0.5,
risk_reward_score=0.5,
quantitative_score=0.5,
)
stocks = [
self._create_mock_trending_stock(f"TICK{i}", 100.0 - i) for i in range(60)
]
result = enhance_with_quantitative_scores(stocks, "2024-01-15", max_stocks=10)
assert mock_calc.call_count == 10
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
)
def test_enhance_handles_partial_failures(self, mock_calc):
from tradingagents.agents.discovery.quantitative_models import (
QuantitativeMetrics,
)
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
def side_effect(ticker, date):
if ticker == "FAIL":
return None
return QuantitativeMetrics(
momentum_score=0.5,
volume_score=0.5,
relative_strength_score=0.5,
risk_reward_score=0.5,
quantitative_score=0.5,
)
mock_calc.side_effect = side_effect
stocks = [
self._create_mock_trending_stock("AAPL", 90.0),
self._create_mock_trending_stock("FAIL", 85.0),
self._create_mock_trending_stock("MSFT", 80.0),
]
result = enhance_with_quantitative_scores(stocks, "2024-01-15")
assert len(result) == 3
successful = [s for s in result if s.quantitative_metrics is not None]
assert len(successful) == 2
@patch(
"tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics"
)
def test_enhance_sorts_by_conviction_score(self, mock_calc):
from tradingagents.agents.discovery.quantitative_models import (
QuantitativeMetrics,
)
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
def side_effect(ticker, date):
scores = {
"LOW": 0.3,
"MED": 0.5,
"HIGH": 0.9,
}
quant_score = scores.get(ticker, 0.5)
return QuantitativeMetrics(
momentum_score=quant_score,
volume_score=quant_score,
relative_strength_score=quant_score,
risk_reward_score=quant_score,
quantitative_score=quant_score,
)
mock_calc.side_effect = side_effect
stocks = [
self._create_mock_trending_stock("LOW", 60.0),
self._create_mock_trending_stock("HIGH", 40.0),
self._create_mock_trending_stock("MED", 50.0),
]
result = enhance_with_quantitative_scores(stocks, "2024-01-15")
assert (
result[0].quantitative_metrics.quantitative_score
>= result[1].quantitative_metrics.quantitative_score
)
assert (
result[1].quantitative_metrics.quantitative_score
>= result[2].quantitative_metrics.quantitative_score
)
class TestErrorHandling:
def test_empty_stock_list(self):
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
result = enhance_with_quantitative_scores([], "2024-01-15")
assert result == []

View File

@ -0,0 +1,190 @@
from unittest.mock import MagicMock, patch
import pytest
class TestCalculateReturn:
def test_calculate_positive_return(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_return,
)
prices = [100.0, 102.0, 105.0, 108.0, 110.0]
ret = calculate_return(prices, days=5)
assert ret == pytest.approx(10.0, rel=0.01)
def test_calculate_negative_return(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_return,
)
prices = [100.0, 98.0, 95.0, 92.0, 90.0]
ret = calculate_return(prices, days=5)
assert ret == pytest.approx(-10.0, rel=0.01)
def test_calculate_return_insufficient_data(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_return,
)
prices = [100.0, 110.0]
ret = calculate_return(prices, days=5)
assert ret == pytest.approx(10.0, rel=0.01)
def test_calculate_return_empty_list(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_return,
)
prices = []
ret = calculate_return(prices, days=5)
assert ret == 0.0
class TestCalculateRelativeStrength:
def test_positive_outperformance(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_relative_strength,
)
rs = calculate_relative_strength(stock_return=15.0, benchmark_return=10.0)
assert rs == 5.0
def test_negative_outperformance(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_relative_strength,
)
rs = calculate_relative_strength(stock_return=5.0, benchmark_return=10.0)
assert rs == -5.0
def test_equal_returns(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_relative_strength,
)
rs = calculate_relative_strength(stock_return=10.0, benchmark_return=10.0)
assert rs == 0.0
class TestSectorEtfMap:
def test_sector_etf_map_contains_expected_sectors(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
SECTOR_ETF_MAP,
)
assert "technology" in SECTOR_ETF_MAP
assert "finance" in SECTOR_ETF_MAP
assert "healthcare" in SECTOR_ETF_MAP
assert "energy" in SECTOR_ETF_MAP
assert "consumer_goods" in SECTOR_ETF_MAP
assert "industrials" in SECTOR_ETF_MAP
assert "other" in SECTOR_ETF_MAP
def test_sector_etf_map_values(self):
from tradingagents.agents.discovery.indicators.relative_strength import (
SECTOR_ETF_MAP,
)
assert SECTOR_ETF_MAP["technology"] == "XLK"
assert SECTOR_ETF_MAP["finance"] == "XLF"
assert SECTOR_ETF_MAP["healthcare"] == "XLV"
assert SECTOR_ETF_MAP["energy"] == "XLE"
assert SECTOR_ETF_MAP["other"] == "SPY"
class TestGetSectorEtf:
@patch(
"tradingagents.agents.discovery.indicators.relative_strength.classify_sector"
)
def test_get_sector_etf_for_tech_stock(self, mock_classify):
from tradingagents.agents.discovery.indicators.relative_strength import (
get_sector_etf,
)
mock_classify.return_value = "technology"
etf = get_sector_etf("AAPL")
assert etf == "XLK"
@patch(
"tradingagents.agents.discovery.indicators.relative_strength.classify_sector"
)
def test_get_sector_etf_for_finance_stock(self, mock_classify):
from tradingagents.agents.discovery.indicators.relative_strength import (
get_sector_etf,
)
mock_classify.return_value = "finance"
etf = get_sector_etf("JPM")
assert etf == "XLF"
@patch(
"tradingagents.agents.discovery.indicators.relative_strength.classify_sector"
)
def test_get_sector_etf_for_unknown_sector(self, mock_classify):
from tradingagents.agents.discovery.indicators.relative_strength import (
get_sector_etf,
)
mock_classify.return_value = "other"
etf = get_sector_etf("XYZ")
assert etf == "SPY"
class TestCalculateRelativeStrengthMetrics:
@patch(
"tradingagents.agents.discovery.indicators.relative_strength._get_price_history"
)
@patch("tradingagents.agents.discovery.indicators.relative_strength.get_sector_etf")
def test_calculate_rs_metrics_returns_dict(
self, mock_sector_etf, mock_price_history
):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_relative_strength_metrics,
)
mock_sector_etf.return_value = "XLK"
base_prices = [100.0 + i * 0.5 for i in range(70)]
spy_prices = [100.0 + i * 0.3 for i in range(70)]
sector_prices = [100.0 + i * 0.4 for i in range(70)]
def price_history_side_effect(ticker, date, days):
if ticker == "AAPL":
return base_prices[-days:]
elif ticker == "SPY":
return spy_prices[-days:]
else:
return sector_prices[-days:]
mock_price_history.side_effect = price_history_side_effect
result = calculate_relative_strength_metrics("AAPL", "2024-01-15")
assert isinstance(result, dict)
assert "rs_vs_spy_5d" in result
assert "rs_vs_spy_20d" in result
assert "rs_vs_spy_60d" in result
assert "rs_vs_sector" in result
assert "sector_etf" in result
assert "relative_strength_score" in result
assert 0.0 <= result["relative_strength_score"] <= 1.0
@patch(
"tradingagents.agents.discovery.indicators.relative_strength._get_price_history"
)
@patch("tradingagents.agents.discovery.indicators.relative_strength.get_sector_etf")
def test_calculate_rs_metrics_handles_missing_benchmark_data(
self, mock_sector_etf, mock_price_history
):
from tradingagents.agents.discovery.indicators.relative_strength import (
calculate_relative_strength_metrics,
)
mock_sector_etf.return_value = "XLK"
mock_price_history.return_value = []
result = calculate_relative_strength_metrics("INVALID", "2024-01-15")
assert isinstance(result, dict)
assert result["relative_strength_score"] == 0.5

View File

@ -0,0 +1,121 @@
import pytest
class TestCalculateStopLoss:
def test_calculate_stop_loss_default_multiplier(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_stop_loss,
)
stop = calculate_stop_loss(price=100.0, atr=2.0, multiplier=1.5)
assert stop == 97.0
def test_calculate_stop_loss_custom_multiplier(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_stop_loss,
)
stop = calculate_stop_loss(price=100.0, atr=2.0, multiplier=2.0)
assert stop == 96.0
def test_calculate_stop_loss_large_atr(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_stop_loss,
)
stop = calculate_stop_loss(price=100.0, atr=10.0, multiplier=1.5)
assert stop == 85.0
class TestCalculateRewardTarget:
def test_calculate_reward_target(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_reward_target,
)
target = calculate_reward_target(price=100.0, resistance=120.0)
assert target == 120.0
def test_calculate_reward_target_resistance_below_price(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_reward_target,
)
target = calculate_reward_target(price=100.0, resistance=90.0)
assert target == 90.0
class TestCalculateRiskRewardRatio:
def test_calculate_rr_ratio_good_trade(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_ratio,
)
rr = calculate_risk_reward_ratio(price=100.0, stop=95.0, target=115.0)
assert rr == 3.0
def test_calculate_rr_ratio_poor_trade(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_ratio,
)
rr = calculate_risk_reward_ratio(price=100.0, stop=95.0, target=102.0)
assert rr == pytest.approx(0.4, rel=0.01)
def test_calculate_rr_ratio_stop_at_price_returns_zero(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_ratio,
)
rr = calculate_risk_reward_ratio(price=100.0, stop=100.0, target=110.0)
assert rr == 0.0
def test_calculate_rr_ratio_target_below_price(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_ratio,
)
rr = calculate_risk_reward_ratio(price=100.0, stop=95.0, target=98.0)
assert rr < 0
class TestCalculateRiskRewardScore:
def test_excellent_rr_ratio_high_score(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_score,
)
score = calculate_risk_reward_score(rr_ratio=3.5)
assert 0.9 <= score <= 1.0
def test_good_rr_ratio_moderate_high_score(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_score,
)
score = calculate_risk_reward_score(rr_ratio=2.5)
assert 0.7 <= score <= 0.9
def test_acceptable_rr_ratio_moderate_score(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_score,
)
score = calculate_risk_reward_score(rr_ratio=1.5)
assert 0.4 <= score <= 0.7
def test_poor_rr_ratio_low_score(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_score,
)
score = calculate_risk_reward_score(rr_ratio=0.5)
assert 0.0 <= score <= 0.4
def test_negative_rr_ratio_very_low_score(self):
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_risk_reward_score,
)
score = calculate_risk_reward_score(rr_ratio=-1.0)
assert score == 0.0

View File

@ -0,0 +1,224 @@
from unittest.mock import MagicMock, patch
import pytest
class TestFindSupportLevels:
def test_find_support_from_lows(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
find_support_levels,
)
lows = (
[100.0, 98.0, 97.0, 99.0, 96.0, 95.0, 97.0, 98.0, 94.0, 93.0]
+ [95.0, 96.0, 94.0, 93.0, 92.0, 91.0, 90.0, 89.0, 88.0, 87.0]
+ [88.0, 89.0, 87.0, 86.0, 85.0] * 6
)
support_20d, support_50d = find_support_levels(
lows, lookback_20=20, lookback_50=50
)
assert support_20d <= 100.0
assert support_50d <= 100.0
def test_find_support_empty_list(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
find_support_levels,
)
lows = []
support_20d, support_50d = find_support_levels(
lows, lookback_20=20, lookback_50=50
)
assert support_20d == 0.0
assert support_50d == 0.0
class TestFindResistanceLevels:
def test_find_resistance_from_highs(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
find_resistance_levels,
)
highs = (
[100.0, 102.0, 105.0, 103.0, 108.0, 106.0, 110.0, 107.0, 112.0, 109.0]
+ [111.0, 113.0, 115.0, 114.0, 117.0, 116.0, 120.0, 118.0, 122.0, 119.0]
+ [120.0, 121.0, 123.0, 124.0, 125.0] * 6
)
resistance_20d, resistance_50d = find_resistance_levels(
highs, lookback_20=20, lookback_50=50
)
assert resistance_20d >= 100.0
assert resistance_50d >= 100.0
def test_find_resistance_empty_list(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
find_resistance_levels,
)
highs = []
resistance_20d, resistance_50d = find_resistance_levels(
highs, lookback_20=20, lookback_50=50
)
assert resistance_20d == 0.0
assert resistance_50d == 0.0
class TestDetectSwingPoints:
def test_detect_swing_highs_and_lows(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
detect_swing_points,
)
prices = [
100,
102,
105,
103,
98,
95,
97,
100,
103,
107,
105,
102,
99,
96,
94,
97,
101,
]
swing_lows, swing_highs = detect_swing_points(prices, n_bars=3)
assert len(swing_lows) >= 0
assert len(swing_highs) >= 0
def test_detect_swing_points_short_list(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
detect_swing_points,
)
prices = [100, 105, 102]
swing_lows, swing_highs = detect_swing_points(prices, n_bars=3)
assert isinstance(swing_lows, list)
assert isinstance(swing_highs, list)
def test_detect_swing_points_empty_list(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
detect_swing_points,
)
prices = []
swing_lows, swing_highs = detect_swing_points(prices, n_bars=5)
assert swing_lows == []
assert swing_highs == []
class TestGetNearestLevels:
def test_get_nearest_support_and_resistance(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
get_nearest_levels,
)
price = 150.0
supports = [140.0, 145.0, 130.0, 120.0]
resistances = [155.0, 160.0, 170.0, 180.0]
nearest_support, nearest_resistance = get_nearest_levels(
price, supports, resistances
)
assert nearest_support == 145.0
assert nearest_resistance == 155.0
def test_get_nearest_levels_no_support_below(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
get_nearest_levels,
)
price = 100.0
supports = [110.0, 120.0, 130.0]
resistances = [150.0, 160.0]
nearest_support, nearest_resistance = get_nearest_levels(
price, supports, resistances
)
assert nearest_support == 0.0
assert nearest_resistance == 150.0
def test_get_nearest_levels_empty_lists(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
get_nearest_levels,
)
price = 150.0
supports = []
resistances = []
nearest_support, nearest_resistance = get_nearest_levels(
price, supports, resistances
)
assert nearest_support == 0.0
assert nearest_resistance == 0.0
class TestCalculateSupportResistanceMetrics:
@patch(
"tradingagents.agents.discovery.indicators.support_resistance._get_ohlc_data"
)
@patch("tradingagents.agents.discovery.indicators.support_resistance._get_atr")
def test_calculate_sr_metrics_returns_dict(self, mock_atr, mock_ohlc):
from tradingagents.agents.discovery.indicators.support_resistance import (
calculate_support_resistance_metrics,
)
mock_ohlc.return_value = {
"highs": [105.0 + i * 0.5 for i in range(60)],
"lows": [95.0 + i * 0.3 for i in range(60)],
"closes": [100.0 + i * 0.4 for i in range(60)],
"current_price": 125.0,
}
mock_atr.return_value = 2.5
result = calculate_support_resistance_metrics("AAPL", "2024-01-15")
assert isinstance(result, dict)
assert "support_level" in result
assert "resistance_level" in result
assert "atr" in result
assert "suggested_stop" in result
assert "reward_target" in result
assert "risk_reward_ratio" in result
assert "risk_reward_score" in result
@patch(
"tradingagents.agents.discovery.indicators.support_resistance._get_ohlc_data"
)
@patch("tradingagents.agents.discovery.indicators.support_resistance._get_atr")
def test_calculate_sr_metrics_handles_missing_data(self, mock_atr, mock_ohlc):
from tradingagents.agents.discovery.indicators.support_resistance import (
calculate_support_resistance_metrics,
)
mock_ohlc.return_value = {
"highs": [],
"lows": [],
"closes": [],
"current_price": None,
}
mock_atr.return_value = None
result = calculate_support_resistance_metrics("INVALID", "2024-01-15")
assert isinstance(result, dict)
assert result["risk_reward_score"] == 0.5

View File

@ -0,0 +1,201 @@
from unittest.mock import MagicMock, patch
import pytest
class TestCalculateVolumeRatio:
def test_volume_ratio_above_average(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_ratio,
)
ratio = calculate_volume_ratio(
current_volume=2_000_000, avg_volume_20d=1_000_000
)
assert ratio == 2.0
def test_volume_ratio_below_average(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_ratio,
)
ratio = calculate_volume_ratio(current_volume=500_000, avg_volume_20d=1_000_000)
assert ratio == 0.5
def test_volume_ratio_zero_average_returns_zero(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_ratio,
)
ratio = calculate_volume_ratio(current_volume=1_000_000, avg_volume_20d=0)
assert ratio == 0.0
class TestCalculateVolumeTrend:
def test_volume_trend_increasing(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_trend,
)
volume_series = [100_000, 150_000, 200_000, 250_000, 300_000]
trend, slope = calculate_volume_trend(volume_series)
assert trend == "increasing"
assert slope > 0
def test_volume_trend_decreasing(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_trend,
)
volume_series = [300_000, 250_000, 200_000, 150_000, 100_000]
trend, slope = calculate_volume_trend(volume_series)
assert trend == "decreasing"
assert slope < 0
def test_volume_trend_flat(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_trend,
)
volume_series = [100_000, 100_100, 99_900, 100_050, 99_950]
trend, slope = calculate_volume_trend(volume_series)
assert trend == "flat"
def test_volume_trend_empty_returns_flat(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_trend,
)
volume_series = []
trend, slope = calculate_volume_trend(volume_series)
assert trend == "flat"
assert slope == 0.0
class TestCalculateDollarVolume:
def test_dollar_volume_calculation(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_dollar_volume,
)
dollar_vol = calculate_dollar_volume(price=150.0, volume=1_000_000)
assert dollar_vol == 150_000_000.0
def test_dollar_volume_zero_price(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_dollar_volume,
)
dollar_vol = calculate_dollar_volume(price=0.0, volume=1_000_000)
assert dollar_vol == 0.0
class TestCalculateVolumeScore:
def test_volume_spike_with_positive_price_high_score(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_score,
)
score = calculate_volume_score(
volume_ratio=2.5,
trend="increasing",
dollar_volume=50_000_000.0,
price_change=5.0,
min_dollar_volume=1_000_000.0,
)
assert 0.8 <= score <= 1.0
def test_above_average_volume_moderate_score(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_score,
)
score = calculate_volume_score(
volume_ratio=1.5,
trend="increasing",
dollar_volume=10_000_000.0,
price_change=2.0,
min_dollar_volume=1_000_000.0,
)
assert 0.7 <= score <= 0.9
def test_normal_volume_neutral_score(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_score,
)
score = calculate_volume_score(
volume_ratio=1.0,
trend="flat",
dollar_volume=5_000_000.0,
price_change=0.5,
min_dollar_volume=1_000_000.0,
)
assert 0.3 <= score <= 0.6
def test_low_dollar_volume_penalized(self):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_score,
)
high_volume_score = calculate_volume_score(
volume_ratio=2.0,
trend="increasing",
dollar_volume=50_000_000.0,
price_change=3.0,
min_dollar_volume=1_000_000.0,
)
low_volume_score = calculate_volume_score(
volume_ratio=2.0,
trend="increasing",
dollar_volume=500_000.0,
price_change=3.0,
min_dollar_volume=1_000_000.0,
)
assert low_volume_score < high_volume_score
class TestCalculateVolumeMetrics:
@patch("tradingagents.agents.discovery.indicators.volume._get_volume_price_data")
def test_calculate_volume_metrics_returns_dict(self, mock_get_data):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_metrics,
)
mock_get_data.return_value = {
"volumes": [1_000_000] * 25,
"prices": [100.0] * 25,
"current_price": 105.0,
"current_volume": 1_500_000,
"price_change_pct": 5.0,
}
result = calculate_volume_metrics("AAPL", "2024-01-15")
assert isinstance(result, dict)
assert "volume_ratio" in result
assert "volume_trend" in result
assert "dollar_volume" in result
assert "volume_score" in result
assert 0.0 <= result["volume_score"] <= 1.0
@patch("tradingagents.agents.discovery.indicators.volume._get_volume_price_data")
def test_calculate_volume_metrics_handles_missing_data(self, mock_get_data):
from tradingagents.agents.discovery.indicators.volume import (
calculate_volume_metrics,
)
mock_get_data.return_value = {
"volumes": [],
"prices": [],
"current_price": None,
"current_volume": None,
"price_change_pct": 0.0,
}
result = calculate_volume_metrics("INVALID", "2024-01-15")
assert isinstance(result, dict)
assert result["volume_score"] == 0.5

View File

@ -22,6 +22,7 @@ from .persistence import (
generate_markdown_summary, generate_markdown_summary,
save_discovery_result, save_discovery_result,
) )
from .quantitative_models import QuantitativeMetrics
from .scorer import ( from .scorer import (
DEFAULT_DECAY_RATE, DEFAULT_DECAY_RATE,
DEFAULT_MAX_RESULTS, DEFAULT_MAX_RESULTS,
@ -50,4 +51,5 @@ __all__ = [
"DEFAULT_MIN_MENTIONS", "DEFAULT_MIN_MENTIONS",
"save_discovery_result", "save_discovery_result",
"generate_markdown_summary", "generate_markdown_summary",
"QuantitativeMetrics",
] ]

View File

@ -0,0 +1,61 @@
from tradingagents.agents.discovery.indicators.momentum import (
calculate_ema_direction,
calculate_macd_score,
calculate_momentum_score,
calculate_rsi_score,
calculate_sma_score,
)
from tradingagents.agents.discovery.indicators.relative_strength import (
SECTOR_ETF_MAP,
calculate_relative_strength,
calculate_relative_strength_metrics,
calculate_return,
get_sector_etf,
)
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_reward_target,
calculate_risk_reward_ratio,
calculate_risk_reward_score,
calculate_stop_loss,
)
from tradingagents.agents.discovery.indicators.support_resistance import (
calculate_support_resistance_metrics,
detect_swing_points,
find_resistance_levels,
find_support_levels,
get_nearest_levels,
)
from tradingagents.agents.discovery.indicators.volume import (
calculate_dollar_volume,
calculate_volume_metrics,
calculate_volume_ratio,
calculate_volume_score,
calculate_volume_trend,
)
__all__ = [
"calculate_rsi_score",
"calculate_macd_score",
"calculate_sma_score",
"calculate_ema_direction",
"calculate_momentum_score",
"calculate_volume_ratio",
"calculate_volume_trend",
"calculate_dollar_volume",
"calculate_volume_score",
"calculate_volume_metrics",
"SECTOR_ETF_MAP",
"calculate_return",
"calculate_relative_strength",
"get_sector_etf",
"calculate_relative_strength_metrics",
"find_support_levels",
"find_resistance_levels",
"detect_swing_points",
"get_nearest_levels",
"calculate_support_resistance_metrics",
"calculate_stop_loss",
"calculate_reward_target",
"calculate_risk_reward_ratio",
"calculate_risk_reward_score",
]

View File

@ -0,0 +1,349 @@
import logging
from datetime import datetime
import pandas as pd
logger = logging.getLogger(__name__)
def calculate_rsi_score(rsi: float) -> float:
if rsi <= 20:
return 1.0
elif rsi <= 30:
return 0.8 + (30 - rsi) / 50
elif rsi <= 35:
return 0.7 + (35 - rsi) / 50
elif rsi <= 50:
return 0.6 + (50 - rsi) / 150
elif rsi <= 65:
return 0.5 + (65 - rsi) / 150
elif rsi <= 70:
return 0.4 + (70 - rsi) / 50
elif rsi <= 80:
return 0.2 + (80 - rsi) / 50
else:
return max(0.0, 0.2 - (rsi - 80) / 100)
def calculate_macd_score(macd: float, signal: float, histogram: float) -> float:
score = 0.5
if macd > signal:
crossover_strength = min((macd - signal) / max(abs(signal), 0.01), 1.0)
score += 0.25 * crossover_strength
else:
crossover_weakness = min((signal - macd) / max(abs(signal), 0.01), 1.0)
score -= 0.25 * crossover_weakness
if histogram > 0:
histogram_bonus = min(histogram / max(abs(macd), 0.01), 1.0) * 0.25
score += histogram_bonus
else:
histogram_penalty = min(abs(histogram) / max(abs(macd), 0.01), 1.0) * 0.25
score -= histogram_penalty
return max(0.0, min(1.0, score))
def calculate_sma_score(price: float, sma50: float, sma200: float) -> float:
if sma50 == 0 or sma200 == 0:
return 0.5
score = 0.5
pct_vs_sma50 = (price - sma50) / sma50
pct_vs_sma200 = (price - sma200) / sma200
if pct_vs_sma50 > 0:
score += min(pct_vs_sma50 * 2, 0.25)
else:
score += max(pct_vs_sma50 * 2, -0.25)
if pct_vs_sma200 > 0:
score += min(pct_vs_sma200 * 2, 0.25)
else:
score += max(pct_vs_sma200 * 2, -0.25)
if price > sma50 > sma200:
score += 0.15
return max(0.0, min(1.0, score))
def calculate_ema_direction(ema_values: list[float]) -> str:
if len(ema_values) < 2:
return "flat"
first_half = ema_values[: len(ema_values) // 2]
second_half = ema_values[len(ema_values) // 2 :]
if not first_half or not second_half:
return "flat"
first_avg = sum(first_half) / len(first_half)
second_avg = sum(second_half) / len(second_half)
pct_change = (second_avg - first_avg) / first_avg if first_avg != 0 else 0
if pct_change > 0.01:
return "up"
elif pct_change < -0.01:
return "down"
else:
return "flat"
def _get_stock_stats_bulk(symbol: str, indicator: str, curr_date: str) -> dict:
import os
from stockstats import wrap
from tradingagents.dataflows.config import get_config
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
if not online:
try:
data = pd.read_csv(
os.path.join(
config.get("data_cache_dir", "data"),
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
df = wrap(data)
except FileNotFoundError:
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
else:
import yfinance as yf
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
else:
data = yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
df[indicator]
indicator_series = df[indicator].apply(lambda x: "N/A" if pd.isna(x) else str(x))
result_dict = dict(zip(df["Date"], indicator_series, strict=False))
return result_dict
def _get_price_data(symbol: str, curr_date: str) -> dict:
import os
from tradingagents.dataflows.config import get_config
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
if not online:
try:
data = pd.read_csv(
os.path.join(
config.get("data_cache_dir", "data"),
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
except FileNotFoundError:
return {}
else:
import yfinance as yf
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
else:
data = yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d")
result_dict = {}
for _, row in data.iterrows():
result_dict[row["Date"]] = row["Close"]
return result_dict
def calculate_momentum_score(ticker: str, curr_date: str) -> dict:
result = {
"rsi": None,
"macd": None,
"macd_signal": None,
"macd_histogram": None,
"price_vs_sma50": None,
"price_vs_sma200": None,
"ema10_direction": None,
"momentum_score": 0.5,
}
try:
rsi_data = _get_stock_stats_bulk(ticker, "rsi", curr_date)
macd_data = _get_stock_stats_bulk(ticker, "macd", curr_date)
macds_data = _get_stock_stats_bulk(ticker, "macds", curr_date)
macdh_data = _get_stock_stats_bulk(ticker, "macdh", curr_date)
sma50_data = _get_stock_stats_bulk(ticker, "close_50_sma", curr_date)
sma200_data = _get_stock_stats_bulk(ticker, "close_200_sma", curr_date)
ema10_data = _get_stock_stats_bulk(ticker, "close_10_ema", curr_date)
price_data = _get_price_data(ticker, curr_date)
rsi_value = None
if curr_date in rsi_data and rsi_data[curr_date] != "N/A":
try:
rsi_value = float(rsi_data[curr_date])
result["rsi"] = rsi_value
except (ValueError, TypeError):
pass
macd_value = None
macds_value = None
macdh_value = None
if curr_date in macd_data and macd_data[curr_date] != "N/A":
try:
macd_value = float(macd_data[curr_date])
result["macd"] = macd_value
except (ValueError, TypeError):
pass
if curr_date in macds_data and macds_data[curr_date] != "N/A":
try:
macds_value = float(macds_data[curr_date])
result["macd_signal"] = macds_value
except (ValueError, TypeError):
pass
if curr_date in macdh_data and macdh_data[curr_date] != "N/A":
try:
macdh_value = float(macdh_data[curr_date])
result["macd_histogram"] = macdh_value
except (ValueError, TypeError):
pass
current_price = None
sma50_value = None
sma200_value = None
if curr_date in price_data:
try:
current_price = float(price_data[curr_date])
except (ValueError, TypeError):
pass
if curr_date in sma50_data and sma50_data[curr_date] != "N/A":
try:
sma50_value = float(sma50_data[curr_date])
except (ValueError, TypeError):
pass
if curr_date in sma200_data and sma200_data[curr_date] != "N/A":
try:
sma200_value = float(sma200_data[curr_date])
except (ValueError, TypeError):
pass
if current_price and sma50_value:
result["price_vs_sma50"] = (
(current_price - sma50_value) / sma50_value
) * 100
if current_price and sma200_value:
result["price_vs_sma200"] = (
(current_price - sma200_value) / sma200_value
) * 100
ema_values = []
sorted_dates = sorted(
[d for d in ema10_data.keys() if d <= curr_date], reverse=True
)[:10]
for date in sorted_dates:
if ema10_data[date] != "N/A":
try:
ema_values.append(float(ema10_data[date]))
except (ValueError, TypeError):
pass
if ema_values:
result["ema10_direction"] = calculate_ema_direction(ema_values[::-1])
scores = []
weights = []
if rsi_value is not None:
scores.append(calculate_rsi_score(rsi_value))
weights.append(0.25)
if (
macd_value is not None
and macds_value is not None
and macdh_value is not None
):
scores.append(calculate_macd_score(macd_value, macds_value, macdh_value))
weights.append(0.35)
if current_price and sma50_value and sma200_value:
scores.append(calculate_sma_score(current_price, sma50_value, sma200_value))
weights.append(0.40)
if scores and weights:
total_weight = sum(weights)
result["momentum_score"] = (
sum(s * w for s, w in zip(scores, weights, strict=False)) / total_weight
)
else:
result["momentum_score"] = 0.5
except (KeyError, ValueError, FileNotFoundError, RuntimeError) as e:
logger.warning("Failed to calculate momentum score for %s: %s", ticker, str(e))
result["momentum_score"] = 0.5
return result

View File

@ -0,0 +1,181 @@
import logging
import os
import pandas as pd
from tradingagents.dataflows.trending.sector_classifier import classify_sector
logger = logging.getLogger(__name__)
SECTOR_ETF_MAP = {
"technology": "XLK",
"finance": "XLF",
"healthcare": "XLV",
"energy": "XLE",
"consumer_goods": "XLY",
"industrials": "XLI",
"other": "SPY",
}
def calculate_return(prices: list[float], days: int) -> float:
if len(prices) < 2:
return 0.0
start_idx = max(0, len(prices) - days)
start_price = prices[start_idx]
end_price = prices[-1]
if start_price == 0:
return 0.0
return ((end_price - start_price) / start_price) * 100
def calculate_relative_strength(stock_return: float, benchmark_return: float) -> float:
return stock_return - benchmark_return
def get_sector_etf(ticker: str) -> str:
sector = classify_sector(ticker)
return SECTOR_ETF_MAP.get(sector, "SPY")
def _get_price_history(ticker: str, curr_date: str, days: int) -> list[float]:
from tradingagents.dataflows.config import get_config
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
try:
if not online:
data = pd.read_csv(
os.path.join(
config.get("data_cache_dir", "data"),
f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
else:
import yfinance as yf
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
else:
data = yf.download(
ticker,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d")
data = data[data["Date"] <= curr_date].tail(days + 5)
return data["Close"].tolist()
except (FileNotFoundError, KeyError, ValueError) as e:
logger.warning("Failed to get price history for %s: %s", ticker, str(e))
return []
def _calculate_rs_score(
rs_5d: float, rs_20d: float, rs_60d: float, rs_sector: float
) -> float:
score = 0.5
if rs_20d > 5:
score += min(rs_20d / 20, 0.25)
elif rs_20d < -5:
score -= min(abs(rs_20d) / 20, 0.25)
if rs_5d > 3:
score += min(rs_5d / 15, 0.15)
elif rs_5d < -3:
score -= min(abs(rs_5d) / 15, 0.15)
if rs_60d > 10:
score += min(rs_60d / 40, 0.1)
elif rs_60d < -10:
score -= min(abs(rs_60d) / 40, 0.1)
if rs_sector > 3:
score += min(rs_sector / 15, 0.1)
elif rs_sector < -3:
score -= min(abs(rs_sector) / 15, 0.1)
return max(0.0, min(1.0, score))
def calculate_relative_strength_metrics(ticker: str, curr_date: str) -> dict:
result = {
"rs_vs_spy_5d": None,
"rs_vs_spy_20d": None,
"rs_vs_spy_60d": None,
"rs_vs_sector": None,
"sector_etf": None,
"relative_strength_score": 0.5,
}
try:
stock_prices = _get_price_history(ticker, curr_date, 70)
spy_prices = _get_price_history("SPY", curr_date, 70)
if not stock_prices or not spy_prices:
return result
sector_etf = get_sector_etf(ticker)
result["sector_etf"] = sector_etf
sector_prices = []
if sector_etf != "SPY":
sector_prices = _get_price_history(sector_etf, curr_date, 70)
stock_5d = calculate_return(stock_prices, 5)
stock_20d = calculate_return(stock_prices, 20)
stock_60d = calculate_return(stock_prices, 60)
spy_5d = calculate_return(spy_prices, 5)
spy_20d = calculate_return(spy_prices, 20)
spy_60d = calculate_return(spy_prices, 60)
result["rs_vs_spy_5d"] = calculate_relative_strength(stock_5d, spy_5d)
result["rs_vs_spy_20d"] = calculate_relative_strength(stock_20d, spy_20d)
result["rs_vs_spy_60d"] = calculate_relative_strength(stock_60d, spy_60d)
if sector_prices and sector_etf != "SPY":
sector_20d = calculate_return(sector_prices, 20)
result["rs_vs_sector"] = calculate_relative_strength(stock_20d, sector_20d)
else:
result["rs_vs_sector"] = result["rs_vs_spy_20d"]
result["relative_strength_score"] = _calculate_rs_score(
result["rs_vs_spy_5d"],
result["rs_vs_spy_20d"],
result["rs_vs_spy_60d"],
result["rs_vs_sector"],
)
except (KeyError, ValueError, RuntimeError) as e:
logger.warning(
"Failed to calculate relative strength for %s: %s", ticker, str(e)
)
return result

View File

@ -0,0 +1,34 @@
import logging
logger = logging.getLogger(__name__)
def calculate_stop_loss(price: float, atr: float, multiplier: float = 1.5) -> float:
return price - (atr * multiplier)
def calculate_reward_target(price: float, resistance: float) -> float:
return resistance
def calculate_risk_reward_ratio(price: float, stop: float, target: float) -> float:
risk = price - stop
if risk == 0:
return 0.0
reward = target - price
return reward / risk
def calculate_risk_reward_score(rr_ratio: float) -> float:
if rr_ratio < 0:
return 0.0
if rr_ratio >= 3.0:
return 0.9 + min((rr_ratio - 3.0) / 10, 0.1)
elif rr_ratio >= 2.0:
return 0.7 + (rr_ratio - 2.0) / 5
elif rr_ratio >= 1.0:
return 0.4 + (rr_ratio - 1.0) * 0.3
else:
return rr_ratio * 0.4

View File

@ -0,0 +1,260 @@
import logging
import os
import pandas as pd
from tradingagents.agents.discovery.indicators.risk_reward import (
calculate_reward_target,
calculate_risk_reward_ratio,
calculate_risk_reward_score,
calculate_stop_loss,
)
logger = logging.getLogger(__name__)
def find_support_levels(
lows: list[float], lookback_20: int, lookback_50: int
) -> tuple[float, float]:
if not lows:
return (0.0, 0.0)
support_20d = (
min(lows[-lookback_20:])
if len(lows) >= lookback_20
else min(lows)
if lows
else 0.0
)
support_50d = (
min(lows[-lookback_50:])
if len(lows) >= lookback_50
else min(lows)
if lows
else 0.0
)
return (support_20d, support_50d)
def find_resistance_levels(
highs: list[float], lookback_20: int, lookback_50: int
) -> tuple[float, float]:
if not highs:
return (0.0, 0.0)
resistance_20d = (
max(highs[-lookback_20:])
if len(highs) >= lookback_20
else max(highs)
if highs
else 0.0
)
resistance_50d = (
max(highs[-lookback_50:])
if len(highs) >= lookback_50
else max(highs)
if highs
else 0.0
)
return (resistance_20d, resistance_50d)
def detect_swing_points(
prices: list[float], n_bars: int = 5
) -> tuple[list[float], list[float]]:
if len(prices) < (2 * n_bars + 1):
return ([], [])
swing_lows = []
swing_highs = []
for i in range(n_bars, len(prices) - n_bars):
is_swing_low = True
is_swing_high = True
current = prices[i]
for j in range(1, n_bars + 1):
if prices[i - j] <= current or prices[i + j] <= current:
is_swing_low = False
if prices[i - j] >= current or prices[i + j] >= current:
is_swing_high = False
if is_swing_low:
swing_lows.append(current)
if is_swing_high:
swing_highs.append(current)
return (swing_lows, swing_highs)
def get_nearest_levels(
price: float, supports: list[float], resistances: list[float]
) -> tuple[float, float]:
supports_below = [s for s in supports if s < price]
resistances_above = [r for r in resistances if r > price]
nearest_support = max(supports_below) if supports_below else 0.0
nearest_resistance = min(resistances_above) if resistances_above else 0.0
return (nearest_support, nearest_resistance)
def _get_ohlc_data(ticker: str, curr_date: str) -> dict:
from tradingagents.dataflows.config import get_config
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
result = {
"highs": [],
"lows": [],
"closes": [],
"current_price": None,
}
try:
if not online:
data = pd.read_csv(
os.path.join(
config.get("data_cache_dir", "data"),
f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
else:
import yfinance as yf
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
else:
data = yf.download(
ticker,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d")
data = data[data["Date"] <= curr_date].tail(60)
if len(data) > 0:
result["highs"] = data["High"].tolist()
result["lows"] = data["Low"].tolist()
result["closes"] = data["Close"].tolist()
result["current_price"] = data["Close"].iloc[-1] if len(data) > 0 else None
except (FileNotFoundError, KeyError, ValueError) as e:
logger.warning("Failed to get OHLC data for %s: %s", ticker, str(e))
return result
def _get_atr(ticker: str, curr_date: str) -> float | None:
from tradingagents.agents.discovery.indicators.momentum import _get_stock_stats_bulk
try:
atr_data = _get_stock_stats_bulk(ticker, "atr", curr_date)
if curr_date in atr_data and atr_data[curr_date] != "N/A":
return float(atr_data[curr_date])
except (KeyError, ValueError, FileNotFoundError) as e:
logger.warning("Failed to get ATR for %s: %s", ticker, str(e))
return None
def calculate_support_resistance_metrics(ticker: str, curr_date: str) -> dict:
result = {
"support_level": None,
"resistance_level": None,
"atr": None,
"suggested_stop": None,
"reward_target": None,
"risk_reward_ratio": None,
"risk_reward_score": 0.5,
}
try:
ohlc = _get_ohlc_data(ticker, curr_date)
atr = _get_atr(ticker, curr_date)
highs = ohlc["highs"]
lows = ohlc["lows"]
closes = ohlc["closes"]
current_price = ohlc["current_price"]
if not highs or not lows or current_price is None:
return result
result["atr"] = atr
support_20d, support_50d = find_support_levels(lows, 20, 50)
resistance_20d, resistance_50d = find_resistance_levels(highs, 20, 50)
swing_lows, swing_highs = detect_swing_points(closes, n_bars=5)
all_supports = [s for s in [support_20d, support_50d] + swing_lows if s > 0]
all_resistances = [
r for r in [resistance_20d, resistance_50d] + swing_highs if r > 0
]
nearest_support, nearest_resistance = get_nearest_levels(
current_price, all_supports, all_resistances
)
result["support_level"] = (
nearest_support if nearest_support > 0 else support_20d
)
result["resistance_level"] = (
nearest_resistance if nearest_resistance > 0 else resistance_20d
)
if atr and atr > 0:
result["suggested_stop"] = calculate_stop_loss(
current_price, atr, multiplier=1.5
)
else:
pct_to_support = (
(current_price - result["support_level"]) / current_price
if result["support_level"] > 0
else 0.02
)
result["suggested_stop"] = current_price * (1 - max(pct_to_support, 0.02))
result["reward_target"] = calculate_reward_target(
current_price, result["resistance_level"]
)
if result["suggested_stop"] and result["reward_target"]:
result["risk_reward_ratio"] = calculate_risk_reward_ratio(
current_price, result["suggested_stop"], result["reward_target"]
)
result["risk_reward_score"] = calculate_risk_reward_score(
result["risk_reward_ratio"]
)
except (KeyError, ValueError, RuntimeError) as e:
logger.warning(
"Failed to calculate support/resistance metrics for %s: %s", ticker, str(e)
)
return result

View File

@ -0,0 +1,203 @@
import logging
logger = logging.getLogger(__name__)
def determine_signal_from_indicators(
rsi: float | None,
macd: float | None,
macd_signal: float | None,
price_vs_sma: float | None,
ema_direction: str | None,
) -> str:
bullish_signals = 0
bearish_signals = 0
total_signals = 0
if rsi is not None:
total_signals += 1
if rsi < 40:
bullish_signals += 1
elif rsi > 60:
bearish_signals += 1
if macd is not None and macd_signal is not None:
total_signals += 1
if macd > macd_signal:
bullish_signals += 1
else:
bearish_signals += 1
if price_vs_sma is not None:
total_signals += 1
if price_vs_sma > 0:
bullish_signals += 1
elif price_vs_sma < 0:
bearish_signals += 1
if ema_direction is not None:
total_signals += 1
if ema_direction == "up":
bullish_signals += 1
elif ema_direction == "down":
bearish_signals += 1
if total_signals == 0:
return "neutral"
bullish_ratio = bullish_signals / total_signals
bearish_ratio = bearish_signals / total_signals
if bullish_ratio >= 0.6:
return "bullish"
elif bearish_ratio >= 0.6:
return "bearish"
else:
return "neutral"
def calculate_timeframe_signals(
momentum_data: dict,
relative_strength_data: dict,
) -> dict:
result = {
"short_term_signal": "neutral",
"medium_term_signal": "neutral",
"long_term_signal": "neutral",
"timeframe_alignment": "neutral",
"signal_strength": 0.5,
}
try:
rsi = momentum_data.get("rsi")
macd = momentum_data.get("macd")
macd_signal = momentum_data.get("macd_signal")
ema_direction = momentum_data.get("ema10_direction")
price_vs_sma50 = momentum_data.get("price_vs_sma50")
price_vs_sma200 = momentum_data.get("price_vs_sma200")
rs_5d = relative_strength_data.get("rs_vs_spy_5d")
rs_20d = relative_strength_data.get("rs_vs_spy_20d")
rs_60d = relative_strength_data.get("rs_vs_spy_60d")
short_bullish = 0
short_bearish = 0
short_total = 0
if rsi is not None:
short_total += 1
if rsi < 35:
short_bullish += 1
elif rsi > 65:
short_bearish += 1
if ema_direction == "up":
short_total += 1
short_bullish += 1
elif ema_direction == "down":
short_total += 1
short_bearish += 1
if rs_5d is not None:
short_total += 1
if rs_5d > 0:
short_bullish += 1
elif rs_5d < 0:
short_bearish += 1
if short_total > 0:
if short_bullish / short_total >= 0.6:
result["short_term_signal"] = "bullish"
elif short_bearish / short_total >= 0.6:
result["short_term_signal"] = "bearish"
med_bullish = 0
med_bearish = 0
med_total = 0
if macd is not None and macd_signal is not None:
med_total += 1
if macd > macd_signal:
med_bullish += 1
else:
med_bearish += 1
if price_vs_sma50 is not None:
med_total += 1
if price_vs_sma50 > 0:
med_bullish += 1
elif price_vs_sma50 < -2:
med_bearish += 1
if rs_20d is not None:
med_total += 1
if rs_20d > 0:
med_bullish += 1
elif rs_20d < 0:
med_bearish += 1
if med_total > 0:
if med_bullish / med_total >= 0.6:
result["medium_term_signal"] = "bullish"
elif med_bearish / med_total >= 0.6:
result["medium_term_signal"] = "bearish"
long_bullish = 0
long_bearish = 0
long_total = 0
if price_vs_sma200 is not None:
long_total += 1
if price_vs_sma200 > 0:
long_bullish += 1
elif price_vs_sma200 < -5:
long_bearish += 1
if rs_60d is not None:
long_total += 1
if rs_60d > 0:
long_bullish += 1
elif rs_60d < 0:
long_bearish += 1
if price_vs_sma50 is not None and price_vs_sma200 is not None:
long_total += 1
if price_vs_sma50 > 0 and price_vs_sma200 > 0:
long_bullish += 1
elif price_vs_sma50 < 0 and price_vs_sma200 < 0:
long_bearish += 1
if long_total > 0:
if long_bullish / long_total >= 0.6:
result["long_term_signal"] = "bullish"
elif long_bearish / long_total >= 0.6:
result["long_term_signal"] = "bearish"
signals = [
result["short_term_signal"],
result["medium_term_signal"],
result["long_term_signal"],
]
bullish_count = signals.count("bullish")
bearish_count = signals.count("bearish")
if bullish_count == 3:
result["timeframe_alignment"] = "aligned_bullish"
result["signal_strength"] = 1.0
elif bearish_count == 3:
result["timeframe_alignment"] = "aligned_bearish"
result["signal_strength"] = 0.0
elif bullish_count >= 2:
result["timeframe_alignment"] = "mixed"
result["signal_strength"] = 0.7
elif bearish_count >= 2:
result["timeframe_alignment"] = "mixed"
result["signal_strength"] = 0.3
else:
result["timeframe_alignment"] = "neutral"
result["signal_strength"] = 0.5
except (KeyError, TypeError, ValueError) as e:
logger.warning("Failed to calculate timeframe signals: %s", str(e))
return result

View File

@ -0,0 +1,213 @@
import logging
import os
import numpy as np
import pandas as pd
from tradingagents.config import get_settings
logger = logging.getLogger(__name__)
def calculate_volume_ratio(current_volume: float, avg_volume_20d: float) -> float:
if avg_volume_20d == 0:
return 0.0
return current_volume / avg_volume_20d
def calculate_volume_trend(volume_series: list[float]) -> tuple[str, float]:
if len(volume_series) < 2:
return ("flat", 0.0)
x = np.arange(len(volume_series))
y = np.array(volume_series)
coeffs = np.polyfit(x, y, 1)
slope = coeffs[0]
avg_volume = np.mean(y)
if avg_volume == 0:
return ("flat", 0.0)
normalized_slope = slope / avg_volume
if normalized_slope > 0.02:
return ("increasing", float(slope))
elif normalized_slope < -0.02:
return ("decreasing", float(slope))
else:
return ("flat", float(slope))
def calculate_dollar_volume(price: float, volume: float) -> float:
return price * volume
def calculate_volume_score(
volume_ratio: float,
trend: str,
dollar_volume: float,
price_change: float,
min_dollar_volume: float,
) -> float:
score = 0.5
if volume_ratio >= 2.0:
if price_change > 0:
score += 0.35
else:
score += 0.15
elif volume_ratio >= 1.5:
if price_change > 0:
score += 0.25
else:
score += 0.10
elif volume_ratio >= 1.0:
score += 0.05
else:
score -= 0.1
if trend == "increasing":
score += 0.1
elif trend == "decreasing":
score -= 0.05
if dollar_volume < min_dollar_volume:
liquidity_penalty = min(
0.3, (min_dollar_volume - dollar_volume) / min_dollar_volume * 0.3
)
score -= liquidity_penalty
return max(0.0, min(1.0, score))
def _get_volume_price_data(ticker: str, curr_date: str) -> dict:
from tradingagents.dataflows.config import get_config
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
result = {
"volumes": [],
"prices": [],
"current_price": None,
"current_volume": None,
"price_change_pct": 0.0,
}
try:
if not online:
data = pd.read_csv(
os.path.join(
config.get("data_cache_dir", "data"),
f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
else:
import yfinance as yf
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
else:
data = yf.download(
ticker,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d")
data = data[data["Date"] <= curr_date].tail(30)
if len(data) > 0:
volumes = data["Volume"].tolist()
prices = data["Close"].tolist()
result["volumes"] = volumes[:-1] if len(volumes) > 1 else volumes
result["prices"] = prices[:-1] if len(prices) > 1 else prices
result["current_volume"] = volumes[-1] if volumes else None
result["current_price"] = prices[-1] if prices else None
if len(prices) >= 2:
result["price_change_pct"] = (
(prices[-1] - prices[-2]) / prices[-2] * 100
if prices[-2] != 0
else 0
)
except (FileNotFoundError, KeyError, ValueError) as e:
logger.warning("Failed to get volume/price data for %s: %s", ticker, str(e))
return result
def calculate_volume_metrics(ticker: str, curr_date: str) -> dict:
result = {
"volume_ratio": None,
"volume_trend": None,
"volume_trend_slope": None,
"dollar_volume": None,
"volume_score": 0.5,
}
try:
settings = get_settings()
min_dollar_volume = settings.min_dollar_volume
data = _get_volume_price_data(ticker, curr_date)
volumes = data["volumes"]
current_volume = data["current_volume"]
current_price = data["current_price"]
price_change = data["price_change_pct"]
if not volumes or current_volume is None or current_price is None:
return result
avg_volume_20d = (
sum(volumes[-20:]) / min(len(volumes[-20:]), 20) if volumes else 0
)
volume_ratio = calculate_volume_ratio(current_volume, avg_volume_20d)
result["volume_ratio"] = volume_ratio
trend_volumes = (
volumes[-10:] + [current_volume] if volumes else [current_volume]
)
trend, slope = calculate_volume_trend(trend_volumes)
result["volume_trend"] = trend
result["volume_trend_slope"] = slope
dollar_volume = calculate_dollar_volume(current_price, current_volume)
result["dollar_volume"] = dollar_volume
result["volume_score"] = calculate_volume_score(
volume_ratio=volume_ratio,
trend=trend,
dollar_volume=dollar_volume,
price_change=price_change,
min_dollar_volume=min_dollar_volume,
)
except (KeyError, ValueError, RuntimeError) as e:
logger.warning("Failed to calculate volume metrics for %s: %s", ticker, str(e))
return result

View File

@ -1,7 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
class DiscoveryStatus(Enum): class DiscoveryStatus(Enum):
@ -50,7 +55,7 @@ class NewsArticle:
} }
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "NewsArticle": def from_dict(cls, data: dict[str, Any]) -> NewsArticle:
return cls( return cls(
title=data["title"], title=data["title"],
source=data["source"], source=data["source"],
@ -72,9 +77,11 @@ class TrendingStock:
event_type: EventCategory event_type: EventCategory
news_summary: str news_summary: str
source_articles: list[NewsArticle] source_articles: list[NewsArticle]
quantitative_metrics: QuantitativeMetrics | None = None
conviction_score: float | None = None
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return { result = {
"ticker": self.ticker, "ticker": self.ticker,
"company_name": self.company_name, "company_name": self.company_name,
"score": self.score, "score": self.score,
@ -85,9 +92,24 @@ class TrendingStock:
"news_summary": self.news_summary, "news_summary": self.news_summary,
"source_articles": [article.to_dict() for article in self.source_articles], "source_articles": [article.to_dict() for article in self.source_articles],
} }
if self.quantitative_metrics is not None:
result["quantitative_metrics"] = self.quantitative_metrics.to_dict()
if self.conviction_score is not None:
result["conviction_score"] = self.conviction_score
return result
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "TrendingStock": def from_dict(cls, data: dict[str, Any]) -> TrendingStock:
from tradingagents.agents.discovery.quantitative_models import (
QuantitativeMetrics,
)
quantitative_metrics = None
if data.get("quantitative_metrics"):
quantitative_metrics = QuantitativeMetrics.from_dict(
data["quantitative_metrics"]
)
return cls( return cls(
ticker=data["ticker"], ticker=data["ticker"],
company_name=data["company_name"], company_name=data["company_name"],
@ -100,6 +122,8 @@ class TrendingStock:
source_articles=[ source_articles=[
NewsArticle.from_dict(article) for article in data["source_articles"] NewsArticle.from_dict(article) for article in data["source_articles"]
], ],
quantitative_metrics=quantitative_metrics,
conviction_score=data.get("conviction_score"),
) )
@ -125,7 +149,7 @@ class DiscoveryRequest:
} }
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "DiscoveryRequest": def from_dict(cls, data: dict[str, Any]) -> DiscoveryRequest:
return cls( return cls(
lookback_period=data["lookback_period"], lookback_period=data["lookback_period"],
sector_filter=( sector_filter=(
@ -165,7 +189,7 @@ class DiscoveryResult:
} }
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "DiscoveryResult": def from_dict(cls, data: dict[str, Any]) -> DiscoveryResult:
return cls( return cls(
request=DiscoveryRequest.from_dict(data["request"]), request=DiscoveryRequest.from_dict(data["request"]),
trending_stocks=[ trending_stocks=[

View File

@ -0,0 +1,47 @@
import logging
from collections import OrderedDict
from datetime import datetime
from typing import Any
import pandas as pd
logger = logging.getLogger(__name__)
MAX_CACHE_SIZE = 100
_price_data_cache: OrderedDict[str, tuple[pd.DataFrame, datetime]] = OrderedDict()
def get_cached_price_data(ticker: str) -> pd.DataFrame | None:
if ticker not in _price_data_cache:
return None
df, timestamp = _price_data_cache[ticker]
_price_data_cache.move_to_end(ticker)
return df.copy()
def set_cached_price_data(ticker: str, data: pd.DataFrame) -> None:
global _price_data_cache
if ticker in _price_data_cache:
_price_data_cache.move_to_end(ticker)
_price_data_cache[ticker] = (data.copy(), datetime.now())
return
while len(_price_data_cache) >= MAX_CACHE_SIZE:
_price_data_cache.popitem(last=False)
_price_data_cache[ticker] = (data.copy(), datetime.now())
logger.debug(
"Cached price data for %s (cache size: %d)", ticker, len(_price_data_cache)
)
def clear_run_cache() -> None:
global _price_data_cache
count = len(_price_data_cache)
_price_data_cache.clear()
logger.debug("Cleared run cache (%d entries removed)", count)

View File

@ -0,0 +1,91 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
class QuantitativeMetrics(BaseModel):
momentum_score: float = Field(ge=0.0, le=1.0)
volume_score: float = Field(ge=0.0, le=1.0)
relative_strength_score: float = Field(ge=0.0, le=1.0)
risk_reward_score: float = Field(ge=0.0, le=1.0)
rsi: float | None = None
macd: float | None = None
macd_signal: float | None = None
macd_histogram: float | None = None
price_vs_sma50: float | None = None
price_vs_sma200: float | None = None
ema10_direction: str | None = None
volume_ratio: float | None = None
volume_trend: str | None = None
dollar_volume: float | None = None
rs_vs_spy_5d: float | None = None
rs_vs_spy_20d: float | None = None
rs_vs_spy_60d: float | None = None
rs_vs_sector: float | None = None
sector_etf: str | None = None
support_level: float | None = None
resistance_level: float | None = None
atr: float | None = None
suggested_stop: float | None = None
reward_target: float | None = None
risk_reward_ratio: float | None = None
timeframe_alignment: str | None = None
short_term_signal: str | None = None
medium_term_signal: str | None = None
long_term_signal: str | None = None
signal_strength: float | None = None
quantitative_score: float = Field(ge=0.0, le=1.0)
@field_validator("ema10_direction")
@classmethod
def validate_ema10_direction(cls, v: str | None) -> str | None:
if v is None:
return v
valid_directions = {"up", "down", "flat"}
if v not in valid_directions:
raise ValueError(f"ema10_direction must be one of {valid_directions}")
return v
@field_validator("volume_trend")
@classmethod
def validate_volume_trend(cls, v: str | None) -> str | None:
if v is None:
return v
valid_trends = {"increasing", "decreasing", "flat"}
if v not in valid_trends:
raise ValueError(f"volume_trend must be one of {valid_trends}")
return v
@field_validator("timeframe_alignment")
@classmethod
def validate_timeframe_alignment(cls, v: str | None) -> str | None:
if v is None:
return v
valid_alignments = {"aligned_bullish", "aligned_bearish", "mixed", "neutral"}
if v not in valid_alignments:
raise ValueError(f"timeframe_alignment must be one of {valid_alignments}")
return v
@field_validator("short_term_signal", "medium_term_signal", "long_term_signal")
@classmethod
def validate_signal(cls, v: str | None) -> str | None:
if v is None:
return v
valid_signals = {"bullish", "bearish", "neutral"}
if v not in valid_signals:
raise ValueError(f"signal must be one of {valid_signals}")
return v
def to_dict(self) -> dict[str, Any]:
return self.model_dump()
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "QuantitativeMetrics":
return cls(**data)

View File

@ -0,0 +1,178 @@
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING
from tradingagents.agents.discovery.indicators import (
calculate_momentum_score,
calculate_relative_strength_metrics,
calculate_support_resistance_metrics,
calculate_volume_metrics,
)
from tradingagents.agents.discovery.indicators.timeframe import (
calculate_timeframe_signals,
)
from tradingagents.agents.discovery.quantitative_cache import clear_run_cache
from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics
from tradingagents.config import QuantitativeWeightsConfig, get_settings
if TYPE_CHECKING:
from tradingagents.agents.discovery.models import TrendingStock
logger = logging.getLogger(__name__)
def calculate_unified_score(
momentum: float,
volume: float,
rs: float,
rr: float,
weights: QuantitativeWeightsConfig,
) -> float:
score = (
momentum * weights.momentum_weight
+ volume * weights.volume_weight
+ rs * weights.relative_strength_weight
+ rr * weights.risk_reward_weight
)
return max(0.0, min(1.0, score))
def calculate_single_stock_metrics(
ticker: str,
curr_date: str,
) -> QuantitativeMetrics | None:
try:
settings = get_settings()
weights = settings.quantitative_weights
momentum = calculate_momentum_score(ticker, curr_date)
volume = calculate_volume_metrics(ticker, curr_date)
rs = calculate_relative_strength_metrics(ticker, curr_date)
sr = calculate_support_resistance_metrics(ticker, curr_date)
timeframe = calculate_timeframe_signals(momentum, rs)
momentum_score = momentum.get("momentum_score", 0.5)
volume_score = volume.get("volume_score", 0.5)
rs_score = rs.get("relative_strength_score", 0.5)
rr_score = sr.get("risk_reward_score", 0.5)
unified_score = calculate_unified_score(
momentum=momentum_score,
volume=volume_score,
rs=rs_score,
rr=rr_score,
weights=weights,
)
return QuantitativeMetrics(
momentum_score=momentum_score,
volume_score=volume_score,
relative_strength_score=rs_score,
risk_reward_score=rr_score,
rsi=momentum.get("rsi"),
macd=momentum.get("macd"),
macd_signal=momentum.get("macd_signal"),
macd_histogram=momentum.get("macd_histogram"),
price_vs_sma50=momentum.get("price_vs_sma50"),
price_vs_sma200=momentum.get("price_vs_sma200"),
ema10_direction=momentum.get("ema10_direction"),
volume_ratio=volume.get("volume_ratio"),
volume_trend=volume.get("volume_trend"),
dollar_volume=volume.get("dollar_volume"),
rs_vs_spy_5d=rs.get("rs_vs_spy_5d"),
rs_vs_spy_20d=rs.get("rs_vs_spy_20d"),
rs_vs_spy_60d=rs.get("rs_vs_spy_60d"),
rs_vs_sector=rs.get("rs_vs_sector"),
sector_etf=rs.get("sector_etf"),
support_level=sr.get("support_level"),
resistance_level=sr.get("resistance_level"),
atr=sr.get("atr"),
suggested_stop=sr.get("suggested_stop"),
reward_target=sr.get("reward_target"),
risk_reward_ratio=sr.get("risk_reward_ratio"),
timeframe_alignment=timeframe.get("timeframe_alignment"),
short_term_signal=timeframe.get("short_term_signal"),
medium_term_signal=timeframe.get("medium_term_signal"),
long_term_signal=timeframe.get("long_term_signal"),
signal_strength=timeframe.get("signal_strength"),
quantitative_score=unified_score,
)
except Exception as e:
logger.warning(
"Failed to calculate quantitative metrics for %s: %s", ticker, str(e)
)
return None
def _normalize_score(score: float, max_score: float) -> float:
if max_score == 0:
return 0.0
return min(1.0, score / max_score)
def enhance_with_quantitative_scores(
stocks: list["TrendingStock"],
curr_date: str,
max_stocks: int = 50,
) -> list["TrendingStock"]:
if not stocks:
return []
settings = get_settings()
weights = settings.quantitative_weights
sorted_stocks = sorted(stocks, key=lambda s: s.score, reverse=True)
stocks_to_process = sorted_stocks[:max_stocks]
remaining_stocks = sorted_stocks[max_stocks:]
clear_run_cache()
results: dict[str, QuantitativeMetrics | None] = {}
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_ticker = {
executor.submit(
calculate_single_stock_metrics, stock.ticker, curr_date
): stock.ticker
for stock in stocks_to_process
}
for future in as_completed(future_to_ticker):
ticker = future_to_ticker[future]
try:
results[ticker] = future.result()
except Exception as e:
logger.warning("Quantitative scoring failed for %s: %s", ticker, str(e))
results[ticker] = None
max_news_score = max((s.score for s in stocks_to_process), default=1.0) or 1.0
for stock in stocks_to_process:
metrics = results.get(stock.ticker)
stock.quantitative_metrics = metrics
news_normalized = _normalize_score(stock.score, max_news_score)
if metrics is not None:
stock.conviction_score = (
weights.news_sentiment_weight * news_normalized
+ weights.quantitative_weight * metrics.quantitative_score
)
else:
stock.conviction_score = news_normalized * weights.news_sentiment_weight
for stock in remaining_stocks:
stock.quantitative_metrics = None
stock.conviction_score = None
enhanced_stocks = sorted(
stocks_to_process,
key=lambda s: s.conviction_score if s.conviction_score is not None else 0.0,
reverse=True,
)
clear_run_cache()
return enhanced_stocks + remaining_stocks

View File

@ -1,7 +1,7 @@
import os import os
from typing import Any, Dict, List, Optional from typing import Any
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@ -12,6 +12,39 @@ class DataVendorsConfig(BaseModel):
news_data: str = "alpha_vantage" news_data: str = "alpha_vantage"
class QuantitativeWeightsConfig(BaseModel):
news_sentiment_weight: float = Field(default=0.50, ge=0.0, le=1.0)
quantitative_weight: float = Field(default=0.50, ge=0.0, le=1.0)
momentum_weight: float = Field(default=0.30, ge=0.0, le=1.0)
volume_weight: float = Field(default=0.25, ge=0.0, le=1.0)
relative_strength_weight: float = Field(default=0.25, ge=0.0, le=1.0)
risk_reward_weight: float = Field(default=0.20, ge=0.0, le=1.0)
@model_validator(mode="after")
def validate_weights_sum(self) -> "QuantitativeWeightsConfig":
top_level_sum = self.news_sentiment_weight + self.quantitative_weight
if abs(top_level_sum - 1.0) > 0.01:
raise ValueError(
f"Top-level weights (news_sentiment_weight + quantitative_weight) "
f"must sum to 1.0, got {top_level_sum}"
)
sub_weights_sum = (
self.momentum_weight
+ self.volume_weight
+ self.relative_strength_weight
+ self.risk_reward_weight
)
if abs(sub_weights_sum - 1.0) > 0.01:
raise ValueError(
f"Sub-weights (momentum + volume + relative_strength + risk_reward) "
f"must sum to 1.0, got {sub_weights_sum}"
)
return self
class TradingAgentsSettings(BaseSettings): class TradingAgentsSettings(BaseSettings):
project_dir: str = Field( project_dir: str = Field(
default_factory=lambda: os.path.abspath( default_factory=lambda: os.path.abspath(
@ -58,6 +91,14 @@ class TradingAgentsSettings(BaseSettings):
data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig) data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig)
tool_vendors: dict[str, Any] = Field(default_factory=dict) tool_vendors: dict[str, Any] = Field(default_factory=dict)
quantitative_weights: QuantitativeWeightsConfig = Field(
default_factory=QuantitativeWeightsConfig
)
quantitative_max_stocks: int = Field(default=50, ge=10, le=100)
quantitative_cache_ttl_intraday: int = Field(default=1, ge=1)
quantitative_cache_ttl_relative_strength: int = Field(default=4, ge=1)
min_dollar_volume: float = Field(default=1_000_000.0, ge=0.0)
model_config = { model_config = {
"env_prefix": "TRADINGAGENTS_", "env_prefix": "TRADINGAGENTS_",
"env_nested_delimiter": "__", "env_nested_delimiter": "__",
@ -104,6 +145,7 @@ class TradingAgentsSettings(BaseSettings):
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
result = self.model_dump() result = self.model_dump()
result["data_vendors"] = self.data_vendors.model_dump() result["data_vendors"] = self.data_vendors.model_dump()
result["quantitative_weights"] = self.quantitative_weights.model_dump()
return result return result
def get_api_key(self, vendor: str) -> str | None: def get_api_key(self, vendor: str) -> str | None:

View File

@ -132,6 +132,11 @@ DEFAULT_TTL_HOURS = {
"get_insider_sentiment": 24, "get_insider_sentiment": 24,
"get_insider_transactions": 24, "get_insider_transactions": 24,
"get_bulk_news": 1, "get_bulk_news": 1,
"quant_indicators": 1,
"volume_analysis": 1,
"relative_strength": 4,
"support_resistance": 1,
"risk_reward": 1,
} }

View File

@ -22,6 +22,9 @@ from tradingagents.agents.discovery import (
calculate_trending_scores, calculate_trending_scores,
extract_entities, extract_entities,
) )
from tradingagents.agents.discovery.quantitative_scorer import (
enhance_with_quantitative_scores,
)
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
get_balance_sheet, get_balance_sheet,
get_cashflow, get_cashflow,
@ -303,12 +306,17 @@ class TradingAgentsGraph:
) )
hard_timeout = self.config.get("discovery_hard_timeout", 120) hard_timeout = self.config.get("discovery_hard_timeout", 120)
enable_quantitative = self.config.get("enable_quantitative_filtering", True)
quantitative_max_stocks = self.config.get("quantitative_max_stocks", 50)
discovery_result = {"stocks": [], "error": None} discovery_result = {"stocks": [], "error": None}
def run_discovery(): def run_discovery():
try: try:
articles = get_bulk_news(request.lookback_period) articles = get_bulk_news(request.lookback_period)
if not articles:
discovery_result["error"] = "No articles returned from news sources"
return
mentions = extract_entities(articles, self.config) mentions = extract_entities(articles, self.config)
@ -326,6 +334,19 @@ class TradingAgentsGraph:
min_mentions=min_mentions, min_mentions=min_mentions,
) )
if enable_quantitative and trending_stocks:
curr_date = date.today().strftime("%Y-%m-%d")
logger.info(
"Enhancing %d stocks with quantitative scores (max: %d)",
len(trending_stocks),
quantitative_max_stocks,
)
trending_stocks = enhance_with_quantitative_scores(
trending_stocks,
curr_date,
max_stocks=quantitative_max_stocks,
)
discovery_result["stocks"] = trending_stocks discovery_result["stocks"] = trending_stocks
except ( except (
ValueError, ValueError,

View File

@ -5149,6 +5149,7 @@ dependencies = [
{ name = "requests" }, { name = "requests" },
{ name = "rich" }, { name = "rich" },
{ name = "setuptools" }, { name = "setuptools" },
{ name = "sqlalchemy" },
{ name = "stockstats" }, { name = "stockstats" },
{ name = "tqdm" }, { name = "tqdm" },
{ name = "tushare" }, { name = "tushare" },
@ -5196,6 +5197,7 @@ requires-dist = [
{ name = "rich", specifier = ">=14.0.0" }, { name = "rich", specifier = ">=14.0.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.2" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.2" },
{ name = "setuptools", specifier = ">=80.9.0" }, { name = "setuptools", specifier = ">=80.9.0" },
{ name = "sqlalchemy", specifier = ">=2.0.0" },
{ name = "stockstats", specifier = ">=0.6.5" }, { name = "stockstats", specifier = ">=0.6.5" },
{ name = "tqdm", specifier = ">=4.67.1" }, { name = "tqdm", specifier = ">=4.67.1" },
{ name = "tushare", specifier = ">=1.4.21" }, { name = "tushare", specifier = ">=1.4.21" },