From 1ba647ae4ac7c2ce05aca9c0abb569b2a7029948 Mon Sep 17 00:00:00 2001 From: Joseph O'Brien <98370624+89jobrien@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:04:18 -0500 Subject: [PATCH] feat: add quantitative scoring with multi-timeframe analysis and CLI enhancements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add quantitative scoring pipeline for discovery with technical indicator analysis: - Momentum, volume, relative strength, and risk/reward scoring - Support/resistance level detection - Gap analysis for price momentum signals - Configurable caching to reduce API calls Implement multi-timeframe signal analysis: - Short-term (5/20 day), medium-term (20/50 day), and long-term (50/200 day) signals - Timeframe alignment detection (aligned_bullish, aligned_bearish, mixed, neutral) - Signal strength calculation based on indicator agreement Enhance CLI discovery display: - Color-coded conviction scores (green/yellow/red thresholds) - Signal column showing timeframe alignment status - News mentions count column Update tests to support new quantitative filtering configuration. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- cli/discovery.py | 136 ++++- tests/discovery/test_api.py | 6 + tests/discovery/test_cli.py | 5 +- tests/discovery/test_integration.py | 43 +- tests/discovery/test_momentum.py | 197 ++++++++ tests/discovery/test_pipeline_integration.py | 344 +++++++++++++ tests/discovery/test_quantitative_cache.py | 101 ++++ tests/discovery/test_quantitative_config.py | 124 +++++ tests/discovery/test_quantitative_gaps.py | 465 ++++++++++++++++++ tests/discovery/test_quantitative_models.py | 267 ++++++++++ tests/discovery/test_quantitative_scorer.py | 312 ++++++++++++ tests/discovery/test_relative_strength.py | 190 +++++++ tests/discovery/test_risk_reward.py | 121 +++++ tests/discovery/test_support_resistance.py | 224 +++++++++ tests/discovery/test_volume.py | 201 ++++++++ tradingagents/agents/discovery/__init__.py | 2 + .../agents/discovery/indicators/__init__.py | 61 +++ .../agents/discovery/indicators/momentum.py | 349 +++++++++++++ .../discovery/indicators/relative_strength.py | 181 +++++++ .../discovery/indicators/risk_reward.py | 34 ++ .../indicators/support_resistance.py | 260 ++++++++++ .../agents/discovery/indicators/timeframe.py | 203 ++++++++ .../agents/discovery/indicators/volume.py | 213 ++++++++ tradingagents/agents/discovery/models.py | 36 +- .../agents/discovery/quantitative_cache.py | 47 ++ .../agents/discovery/quantitative_models.py | 91 ++++ .../agents/discovery/quantitative_scorer.py | 178 +++++++ tradingagents/config.py | 46 +- .../database/services/market_data.py | 5 + tradingagents/graph/trading_graph.py | 21 + uv.lock | 2 + 31 files changed, 4441 insertions(+), 24 deletions(-) create mode 100644 tests/discovery/test_momentum.py create mode 100644 tests/discovery/test_pipeline_integration.py create mode 100644 tests/discovery/test_quantitative_cache.py create mode 100644 tests/discovery/test_quantitative_config.py create mode 100644 tests/discovery/test_quantitative_gaps.py create mode 100644 tests/discovery/test_quantitative_models.py create mode 100644 tests/discovery/test_quantitative_scorer.py create mode 100644 tests/discovery/test_relative_strength.py create mode 100644 tests/discovery/test_risk_reward.py create mode 100644 tests/discovery/test_support_resistance.py create mode 100644 tests/discovery/test_volume.py create mode 100644 tradingagents/agents/discovery/indicators/__init__.py create mode 100644 tradingagents/agents/discovery/indicators/momentum.py create mode 100644 tradingagents/agents/discovery/indicators/relative_strength.py create mode 100644 tradingagents/agents/discovery/indicators/risk_reward.py create mode 100644 tradingagents/agents/discovery/indicators/support_resistance.py create mode 100644 tradingagents/agents/discovery/indicators/timeframe.py create mode 100644 tradingagents/agents/discovery/indicators/volume.py create mode 100644 tradingagents/agents/discovery/quantitative_cache.py create mode 100644 tradingagents/agents/discovery/quantitative_models.py create mode 100644 tradingagents/agents/discovery/quantitative_scorer.py diff --git a/cli/discovery.py b/cli/discovery.py index 8205dd6c..ff54b9a1 100644 --- a/cli/discovery.py +++ b/cli/discovery.py @@ -153,6 +153,35 @@ def select_event_filter() -> list[EventCategory] | None: return choices +def _get_conviction_display(stock: TrendingStock) -> tuple[str, str]: + if stock.conviction_score is None: + return "-", "dim" + score = stock.conviction_score + if score >= 0.7: + return f"{score:.2f}", "bold green" + elif score >= 0.5: + return f"{score:.2f}", "yellow" + else: + return f"{score:.2f}", "red" + + +def _get_signal_display(stock: TrendingStock) -> str: + if stock.quantitative_metrics is None: + return "-" + alignment = stock.quantitative_metrics.timeframe_alignment + if alignment == "aligned_bullish": + return "[bold green]+++[/bold green]" + elif alignment == "aligned_bearish": + return "[bold red]---[/bold red]" + elif alignment == "mixed": + strength = stock.quantitative_metrics.signal_strength or 0.5 + if strength > 0.5: + return "[yellow]++[/yellow]" + else: + return "[yellow]--[/yellow]" + return "[dim]~[/dim]" + + def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Table: table = Table( show_header=True, @@ -164,11 +193,12 @@ def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Tabl ) table.add_column("Rank", style="cyan", justify="center", width=6) - table.add_column("Ticker", style="bold yellow", justify="center", width=10) - table.add_column("Company", style="white", justify="left", width=25) - table.add_column("Score", style="green", justify="right", width=10) - table.add_column("Mentions", style="blue", justify="center", width=10) - table.add_column("Event Type", style="magenta", justify="center", width=18) + table.add_column("Ticker", style="bold yellow", justify="center", width=8) + table.add_column("Company", style="white", justify="left", width=20) + table.add_column("Conv.", justify="right", width=6) + table.add_column("Signal", justify="center", width=7) + table.add_column("News", style="blue", justify="right", width=6) + table.add_column("Event Type", style="magenta", justify="center", width=15) for rank, stock in enumerate(trending_stocks, 1): if rank <= 3: @@ -178,20 +208,54 @@ def create_discovery_results_table(trending_stocks: list[TrendingStock]) -> Tabl rank_display = str(rank) ticker_display = stock.ticker + conviction_text, conviction_style = _get_conviction_display(stock) + signal_display = _get_signal_display(stock) + table.add_row( rank_display, ticker_display, - stock.company_name[:25] - if len(stock.company_name) > 25 + stock.company_name[:20] + if len(stock.company_name) > 20 else stock.company_name, - f"{stock.score:.2f}", - str(stock.mention_count), - stock.event_type.value.replace("_", " ").title(), + f"[{conviction_style}]{conviction_text}[/{conviction_style}]", + signal_display, + f"{stock.score:.1f}", + stock.event_type.value.replace("_", " ").title()[:15], ) return table +def _format_timeframe_signals(stock: TrendingStock) -> str: + if stock.quantitative_metrics is None: + return "[dim]No quantitative data available[/dim]" + + qm = stock.quantitative_metrics + short_color = ( + "green" + if qm.short_term_signal == "bullish" + else "red" + if qm.short_term_signal == "bearish" + else "yellow" + ) + med_color = ( + "green" + if qm.medium_term_signal == "bullish" + else "red" + if qm.medium_term_signal == "bearish" + else "yellow" + ) + long_color = ( + "green" + if qm.long_term_signal == "bullish" + else "red" + if qm.long_term_signal == "bearish" + else "yellow" + ) + + return f"[{short_color}]Short: {(qm.short_term_signal or 'N/A').upper()}[/{short_color}] | [{med_color}]Med: {(qm.medium_term_signal or 'N/A').upper()}[/{med_color}] | [{long_color}]Long: {(qm.long_term_signal or 'N/A').upper()}[/{long_color}]" + + def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel: sentiment_label = ( "positive" @@ -208,14 +272,64 @@ def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel: else "yellow" ) + conviction_text = ( + f"{stock.conviction_score:.2f}" if stock.conviction_score is not None else "N/A" + ) + conviction_color = ( + "green" + if stock.conviction_score and stock.conviction_score >= 0.7 + else "yellow" + if stock.conviction_score and stock.conviction_score >= 0.5 + else "red" + ) + content = f"""[bold]Rank #{rank}: {stock.ticker} - {stock.company_name}[/bold] -[cyan]Score:[/cyan] {stock.score:.2f} +[cyan]Conviction Score:[/cyan] [{conviction_color}]{conviction_text}[/{conviction_color}] +[cyan]News Score:[/cyan] {stock.score:.2f} [cyan]Sentiment:[/cyan] [{sentiment_color}]{stock.sentiment:.2f} ({sentiment_label})[/{sentiment_color}] [cyan]Sector:[/cyan] {stock.sector.value.replace("_", " ").title()} [cyan]Event Type:[/cyan] {stock.event_type.value.replace("_", " ").title()} [cyan]Mentions:[/cyan] {stock.mention_count} +[bold]Timeframe Signals:[/bold] +{_format_timeframe_signals(stock)}""" + + if stock.quantitative_metrics is not None: + qm = stock.quantitative_metrics + alignment_color = ( + "green" + if qm.timeframe_alignment == "aligned_bullish" + else "red" + if qm.timeframe_alignment == "aligned_bearish" + else "yellow" + ) + content += f""" + +[bold]Quantitative Metrics:[/bold] +[cyan]Timeframe Alignment:[/cyan] [{alignment_color}]{(qm.timeframe_alignment or 'N/A').replace('_', ' ').upper()}[/{alignment_color}] +[cyan]Momentum Score:[/cyan] {qm.momentum_score:.2f} [cyan]Volume Score:[/cyan] {qm.volume_score:.2f} +[cyan]Relative Strength:[/cyan] {qm.relative_strength_score:.2f} [cyan]Risk/Reward:[/cyan] {qm.risk_reward_score:.2f}""" + + if qm.rsi is not None: + rsi_color = "green" if qm.rsi < 35 else "red" if qm.rsi > 65 else "yellow" + content += f"\n[cyan]RSI:[/cyan] [{rsi_color}]{qm.rsi:.1f}[/{rsi_color}]" + + if qm.support_level is not None and qm.resistance_level is not None: + content += f" [cyan]Support:[/cyan] ${qm.support_level:.2f} [cyan]Resistance:[/cyan] ${qm.resistance_level:.2f}" + + if qm.risk_reward_ratio is not None: + rr_color = ( + "green" + if qm.risk_reward_ratio >= 2.0 + else "yellow" + if qm.risk_reward_ratio >= 1.0 + else "red" + ) + content += f"\n[cyan]Risk/Reward Ratio:[/cyan] [{rr_color}]{qm.risk_reward_ratio:.2f}:1[/{rr_color}]" + + content += f""" + [bold]News Summary:[/bold] {stock.news_summary} diff --git a/tests/discovery/test_api.py b/tests/discovery/test_api.py index 9d46e8f9..73fdc146 100644 --- a/tests/discovery/test_api.py +++ b/tests/discovery/test_api.py @@ -67,7 +67,9 @@ class TestDiscoverTrendingReturnsDiscoveryResult: "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "enable_quantitative_filtering": False, } + graph.db_enabled = False result = graph.discover_trending() @@ -118,7 +120,9 @@ class TestSectorFilterParameter: "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "enable_quantitative_filtering": False, } + graph.db_enabled = False request = DiscoveryRequest( lookback_period="24h", @@ -162,7 +166,9 @@ class TestEventFilterParameter: "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "enable_quantitative_filtering": False, } + graph.db_enabled = False request = DiscoveryRequest( lookback_period="24h", diff --git a/tests/discovery/test_cli.py b/tests/discovery/test_cli.py index 131497e8..91e3c718 100644 --- a/tests/discovery/test_cli.py +++ b/tests/discovery/test_cli.py @@ -124,8 +124,9 @@ class TestResultsTableDisplay: "Rank", "Ticker", "Company", - "Score", - "Mentions", + "Conv.", + "Signal", + "News", "Event Type", ] for expected in expected_columns: diff --git a/tests/discovery/test_integration.py b/tests/discovery/test_integration.py index f82d8a10..db402239 100644 --- a/tests/discovery/test_integration.py +++ b/tests/discovery/test_integration.py @@ -89,7 +89,9 @@ class TestEndToEndDiscoveryFlow: "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "enable_quantitative_filtering": False, } + graph.db_enabled = False request = DiscoveryRequest(lookback_period="24h") result = graph.discover_trending(request) @@ -259,7 +261,9 @@ class TestNoTrendingStocksFound: "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "enable_quantitative_filtering": False, } + graph.db_enabled = False result = graph.discover_trending() @@ -274,7 +278,16 @@ class TestAllStocksFilteredOutBySectorFilter: def test_all_stocks_filtered_out_by_sector_filter( self, mock_scores, mock_extract, mock_bulk_news ): - mock_bulk_news.return_value = [] + mock_bulk_news.return_value = [ + NewsArticle( + title="Test article", + source="Test", + url="https://test.com", + published_at=datetime.now(), + content_snippet="Test content", + ticker_mentions=["AAPL"], + ) + ] mock_extract.return_value = [] mock_scores.return_value = [ TrendingStock( @@ -311,7 +324,9 @@ class TestAllStocksFilteredOutBySectorFilter: "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "enable_quantitative_filtering": False, } + graph.db_enabled = False request = DiscoveryRequest( lookback_period="24h", @@ -330,7 +345,16 @@ class TestAllStocksFilteredOutByEventFilter: def test_all_stocks_filtered_out_by_event_filter( self, mock_scores, mock_extract, mock_bulk_news ): - mock_bulk_news.return_value = [] + mock_bulk_news.return_value = [ + NewsArticle( + title="Test article", + source="Test", + url="https://test.com", + published_at=datetime.now(), + content_snippet="Test content", + ticker_mentions=["AAPL"], + ) + ] mock_extract.return_value = [] mock_scores.return_value = [ TrendingStock( @@ -356,7 +380,9 @@ class TestAllStocksFilteredOutByEventFilter: "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "enable_quantitative_filtering": False, } + graph.db_enabled = False request = DiscoveryRequest( lookback_period="24h", @@ -375,7 +401,16 @@ class TestMultipleSectorsAndEventsFiltering: def test_combined_sector_and_event_filtering( self, mock_scores, mock_extract, mock_bulk_news ): - mock_bulk_news.return_value = [] + mock_bulk_news.return_value = [ + NewsArticle( + title="Test article", + source="Test", + url="https://test.com", + published_at=datetime.now(), + content_snippet="Test content", + ticker_mentions=["AAPL"], + ) + ] mock_extract.return_value = [] mock_scores.return_value = [ TrendingStock( @@ -423,7 +458,9 @@ class TestMultipleSectorsAndEventsFiltering: "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "enable_quantitative_filtering": False, } + graph.db_enabled = False request = DiscoveryRequest( lookback_period="24h", diff --git a/tests/discovery/test_momentum.py b/tests/discovery/test_momentum.py new file mode 100644 index 00000000..e8039623 --- /dev/null +++ b/tests/discovery/test_momentum.py @@ -0,0 +1,197 @@ +from unittest.mock import MagicMock, patch + +import pytest + + +class TestCalculateRsiScore: + def test_rsi_oversold_bullish_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_rsi_score, + ) + + score = calculate_rsi_score(25.0) + assert 0.8 <= score <= 1.0 + + score = calculate_rsi_score(30.0) + assert 0.7 <= score <= 0.9 + + def test_rsi_neutral_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_rsi_score, + ) + + score = calculate_rsi_score(50.0) + assert 0.5 <= score <= 0.7 + + score = calculate_rsi_score(40.0) + assert 0.5 <= score <= 0.7 + + def test_rsi_overbought_warning_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_rsi_score, + ) + + score = calculate_rsi_score(70.0) + assert 0.3 <= score <= 0.5 + + score = calculate_rsi_score(75.0) + assert 0.2 <= score <= 0.5 + + def test_rsi_extreme_overbought_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_rsi_score, + ) + + score = calculate_rsi_score(85.0) + assert 0.0 <= score <= 0.3 + + score = calculate_rsi_score(95.0) + assert 0.0 <= score <= 0.2 + + +class TestCalculateMacdScore: + def test_macd_bullish_crossover_high_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_macd_score, + ) + + score = calculate_macd_score(macd=1.5, signal=1.0, histogram=0.5) + assert 0.7 <= score <= 1.0 + + def test_macd_bearish_crossover_low_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_macd_score, + ) + + score = calculate_macd_score(macd=-1.5, signal=-1.0, histogram=-0.5) + assert 0.0 <= score <= 0.4 + + def test_macd_neutral_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_macd_score, + ) + + score = calculate_macd_score(macd=0.1, signal=0.1, histogram=0.0) + assert 0.4 <= score <= 0.6 + + def test_macd_expanding_histogram_bullish(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_macd_score, + ) + + score = calculate_macd_score(macd=2.0, signal=1.5, histogram=0.8) + assert 0.6 <= score <= 1.0 + + +class TestCalculateSmaScore: + def test_price_above_both_smas_high_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_sma_score, + ) + + score = calculate_sma_score(price=150.0, sma50=140.0, sma200=130.0) + assert 0.7 <= score <= 1.0 + + def test_price_below_both_smas_low_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_sma_score, + ) + + score = calculate_sma_score(price=120.0, sma50=140.0, sma200=150.0) + assert 0.0 <= score <= 0.4 + + def test_price_between_smas_moderate_score(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_sma_score, + ) + + score = calculate_sma_score(price=145.0, sma50=150.0, sma200=130.0) + assert 0.4 <= score <= 0.7 + + def test_golden_alignment_bonus(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_sma_score, + ) + + score = calculate_sma_score(price=160.0, sma50=150.0, sma200=140.0) + assert score >= 0.8 + + +class TestCalculateEmaDirection: + def test_ema_upward_trend(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_ema_direction, + ) + + ema_values = [100.0, 102.0, 104.0, 106.0, 108.0] + direction = calculate_ema_direction(ema_values) + assert direction == "up" + + def test_ema_downward_trend(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_ema_direction, + ) + + ema_values = [108.0, 106.0, 104.0, 102.0, 100.0] + direction = calculate_ema_direction(ema_values) + assert direction == "down" + + def test_ema_flat_trend(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_ema_direction, + ) + + ema_values = [100.0, 100.1, 99.9, 100.0, 100.1] + direction = calculate_ema_direction(ema_values) + assert direction == "flat" + + def test_ema_empty_list_returns_flat(self): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_ema_direction, + ) + + ema_values = [] + direction = calculate_ema_direction(ema_values) + assert direction == "flat" + + +class TestCalculateMomentumScore: + @patch("tradingagents.agents.discovery.indicators.momentum._get_stock_stats_bulk") + def test_calculate_momentum_score_returns_dict(self, mock_get_stats): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_momentum_score, + ) + + mock_get_stats.side_effect = lambda symbol, indicator, date: { + "2024-01-15": "50.0" if indicator == "rsi" else "1.0", + "2024-01-14": "49.0" if indicator == "rsi" else "0.9", + "2024-01-13": "48.0" if indicator == "rsi" else "0.8", + "2024-01-12": "47.0" if indicator == "rsi" else "0.7", + "2024-01-11": "46.0" if indicator == "rsi" else "0.6", + } + + result = calculate_momentum_score("AAPL", "2024-01-15") + + assert isinstance(result, dict) + assert "rsi" in result + assert "macd" in result + assert "macd_signal" in result + assert "macd_histogram" in result + assert "price_vs_sma50" in result + assert "price_vs_sma200" in result + assert "ema10_direction" in result + assert "momentum_score" in result + assert 0.0 <= result["momentum_score"] <= 1.0 + + @patch("tradingagents.agents.discovery.indicators.momentum._get_stock_stats_bulk") + def test_calculate_momentum_score_handles_missing_data(self, mock_get_stats): + from tradingagents.agents.discovery.indicators.momentum import ( + calculate_momentum_score, + ) + + mock_get_stats.return_value = {} + + result = calculate_momentum_score("INVALID", "2024-01-15") + + assert isinstance(result, dict) + assert result["momentum_score"] == 0.5 diff --git a/tests/discovery/test_pipeline_integration.py b/tests/discovery/test_pipeline_integration.py new file mode 100644 index 00000000..89532726 --- /dev/null +++ b/tests/discovery/test_pipeline_integration.py @@ -0,0 +1,344 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + + +class TestDiscoverTrendingIntegration: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + @patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores") + def test_discover_trending_calls_quantitative_enhancement( + self, mock_enhance, mock_scores, mock_extract, mock_bulk_news + ): + from tradingagents.agents.discovery.models import ( + EventCategory, + Sector, + TrendingStock, + ) + from tradingagents.dataflows.models import NewsArticle + from tradingagents.graph.trading_graph import TradingAgentsGraph + + mock_bulk_news.return_value = [ + NewsArticle( + title="Test Article", + source="Test", + url="http://test.com", + published_at=datetime.now(), + content_snippet="Test content", + ticker_mentions=["AAPL"], + ) + ] + mock_extract.return_value = [] + + mock_stock = TrendingStock( + ticker="AAPL", + company_name="Apple Inc", + score=85.0, + mention_count=10, + sentiment=0.7, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Test", + source_articles=[], + ) + mock_scores.return_value = [mock_stock] + mock_enhance.return_value = [mock_stock] + + config = { + "llm_provider": "openai", + "quick_think_llm": "gpt-4o-mini", + "deep_think_llm": "gpt-4o", + "backend_url": "https://api.openai.com/v1", + "project_dir": "/tmp/test", + "database_enabled": False, + "enable_quantitative_filtering": True, + } + + with ( + patch("tradingagents.graph.trading_graph.ChatOpenAI"), + patch("tradingagents.graph.trading_graph.FinancialSituationMemory"), + patch("tradingagents.graph.trading_graph.GraphSetup"), + ): + graph = TradingAgentsGraph(config=config) + result = graph.discover_trending() + + mock_enhance.assert_called_once() + + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + @patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores") + def test_discover_trending_skips_quantitative_when_disabled( + self, mock_enhance, mock_scores, mock_extract, mock_bulk_news + ): + from tradingagents.agents.discovery.models import ( + EventCategory, + Sector, + TrendingStock, + ) + from tradingagents.dataflows.models import NewsArticle + from tradingagents.graph.trading_graph import TradingAgentsGraph + + mock_bulk_news.return_value = [ + NewsArticle( + title="Test Article", + source="Test", + url="http://test.com", + published_at=datetime.now(), + content_snippet="Test content", + ticker_mentions=["AAPL"], + ) + ] + mock_extract.return_value = [] + + mock_stock = TrendingStock( + ticker="AAPL", + company_name="Apple Inc", + score=85.0, + mention_count=10, + sentiment=0.7, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Test", + source_articles=[], + ) + mock_scores.return_value = [mock_stock] + + config = { + "llm_provider": "openai", + "quick_think_llm": "gpt-4o-mini", + "deep_think_llm": "gpt-4o", + "backend_url": "https://api.openai.com/v1", + "project_dir": "/tmp/test", + "database_enabled": False, + "enable_quantitative_filtering": False, + } + + with ( + patch("tradingagents.graph.trading_graph.ChatOpenAI"), + patch("tradingagents.graph.trading_graph.FinancialSituationMemory"), + patch("tradingagents.graph.trading_graph.GraphSetup"), + ): + graph = TradingAgentsGraph(config=config) + result = graph.discover_trending() + + mock_enhance.assert_not_called() + + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + @patch("tradingagents.graph.trading_graph.enhance_with_quantitative_scores") + def test_discover_trending_uses_config_max_stocks( + self, mock_enhance, mock_scores, mock_extract, mock_bulk_news + ): + from tradingagents.agents.discovery.models import ( + EventCategory, + Sector, + TrendingStock, + ) + from tradingagents.dataflows.models import NewsArticle + from tradingagents.graph.trading_graph import TradingAgentsGraph + + mock_bulk_news.return_value = [ + NewsArticle( + title="Test Article", + source="Test", + url="http://test.com", + published_at=datetime.now(), + content_snippet="Test content", + ticker_mentions=["AAPL"], + ) + ] + mock_extract.return_value = [] + + mock_stocks = [ + TrendingStock( + ticker=f"TICK{i}", + company_name=f"Company {i}", + score=100.0 - i, + mention_count=10, + sentiment=0.5, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.OTHER, + news_summary="Test", + source_articles=[], + ) + for i in range(30) + ] + mock_scores.return_value = mock_stocks + mock_enhance.return_value = mock_stocks + + config = { + "llm_provider": "openai", + "quick_think_llm": "gpt-4o-mini", + "deep_think_llm": "gpt-4o", + "backend_url": "https://api.openai.com/v1", + "project_dir": "/tmp/test", + "database_enabled": False, + "enable_quantitative_filtering": True, + "quantitative_max_stocks": 25, + } + + with ( + patch("tradingagents.graph.trading_graph.ChatOpenAI"), + patch("tradingagents.graph.trading_graph.FinancialSituationMemory"), + patch("tradingagents.graph.trading_graph.GraphSetup"), + ): + graph = TradingAgentsGraph(config=config) + result = graph.discover_trending() + + call_args = mock_enhance.call_args + assert call_args[1].get("max_stocks", 50) == 25 + + +class TestScorerConvictionSupport: + def test_calculate_trending_scores_preserves_original_score(self): + from tradingagents.agents.discovery.entity_extractor import EntityMention + from tradingagents.agents.discovery.models import ( + EventCategory, + NewsArticle, + ) + from tradingagents.agents.discovery.scorer import calculate_trending_scores + + mentions = [ + EntityMention( + company_name="Apple", + confidence=0.9, + sentiment=0.7, + event_type=EventCategory.EARNINGS, + context_snippet="Apple reports strong earnings", + article_id="article_0", + ), + EntityMention( + company_name="Apple", + confidence=0.85, + sentiment=0.6, + event_type=EventCategory.EARNINGS, + context_snippet="Apple stock rises", + article_id="article_1", + ), + ] + + articles = [ + NewsArticle( + title="Article 1", + source="Test", + url="http://test.com/1", + published_at=datetime.now(), + content_snippet="Test content 1", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="Article 2", + source="Test", + url="http://test.com/2", + published_at=datetime.now(), + content_snippet="Test content 2", + ticker_mentions=["AAPL"], + ), + ] + + with patch( + "tradingagents.agents.discovery.scorer.resolve_ticker" + ) as mock_resolve: + mock_resolve.return_value = "AAPL" + result = calculate_trending_scores(mentions, articles) + + assert len(result) == 1 + assert result[0].ticker == "AAPL" + assert result[0].score > 0 + assert result[0].conviction_score is None + + def test_trending_stock_supports_conviction_score(self): + from tradingagents.agents.discovery.models import ( + EventCategory, + Sector, + TrendingStock, + ) + from tradingagents.agents.discovery.quantitative_models import ( + QuantitativeMetrics, + ) + + stock = TrendingStock( + ticker="AAPL", + company_name="Apple Inc", + score=85.0, + mention_count=10, + sentiment=0.7, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Test", + source_articles=[], + ) + + assert stock.conviction_score is None + + stock.conviction_score = 0.85 + assert stock.conviction_score == 0.85 + + metrics = QuantitativeMetrics( + momentum_score=0.7, + volume_score=0.6, + relative_strength_score=0.65, + risk_reward_score=0.7, + quantitative_score=0.66, + ) + stock.quantitative_metrics = metrics + assert stock.quantitative_metrics.quantitative_score == 0.66 + + +class TestBackwardCompatibility: + def test_trending_stock_without_quantitative_fields(self): + from tradingagents.agents.discovery.models import ( + EventCategory, + Sector, + TrendingStock, + ) + + stock = TrendingStock( + ticker="AAPL", + company_name="Apple Inc", + score=85.0, + mention_count=10, + sentiment=0.7, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Test", + source_articles=[], + ) + + assert stock.quantitative_metrics is None + assert stock.conviction_score is None + + stock_dict = stock.to_dict() + assert ( + "quantitative_metrics" not in stock_dict + or stock_dict.get("quantitative_metrics") is None + ) + assert ( + "conviction_score" not in stock_dict + or stock_dict.get("conviction_score") is None + ) + + def test_trending_stock_from_dict_without_quantitative_fields(self): + from tradingagents.agents.discovery.models import TrendingStock + + data = { + "ticker": "AAPL", + "company_name": "Apple Inc", + "score": 85.0, + "mention_count": 10, + "sentiment": 0.7, + "sector": "technology", + "event_type": "earnings", + "news_summary": "Test", + "source_articles": [], + } + + stock = TrendingStock.from_dict(data) + + assert stock.ticker == "AAPL" + assert stock.quantitative_metrics is None + assert stock.conviction_score is None diff --git a/tests/discovery/test_quantitative_cache.py b/tests/discovery/test_quantitative_cache.py new file mode 100644 index 00000000..ebbd548c --- /dev/null +++ b/tests/discovery/test_quantitative_cache.py @@ -0,0 +1,101 @@ +import time +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + + +class TestQuantitativeCache: + def test_cache_set_and_get(self): + from tradingagents.agents.discovery.quantitative_cache import ( + clear_run_cache, + get_cached_price_data, + set_cached_price_data, + ) + + clear_run_cache() + + df = pd.DataFrame( + {"Close": [100.0, 101.0, 102.0], "Volume": [1000, 1100, 1200]} + ) + + set_cached_price_data("AAPL", df) + + cached = get_cached_price_data("AAPL") + + assert cached is not None + assert len(cached) == 3 + assert cached["Close"].tolist() == [100.0, 101.0, 102.0] + + def test_cache_miss_returns_none(self): + from tradingagents.agents.discovery.quantitative_cache import ( + clear_run_cache, + get_cached_price_data, + ) + + clear_run_cache() + + result = get_cached_price_data("NONEXISTENT") + + assert result is None + + def test_cache_clear(self): + from tradingagents.agents.discovery.quantitative_cache import ( + clear_run_cache, + get_cached_price_data, + set_cached_price_data, + ) + + clear_run_cache() + + df = pd.DataFrame({"Close": [100.0]}) + set_cached_price_data("AAPL", df) + + assert get_cached_price_data("AAPL") is not None + + clear_run_cache() + + assert get_cached_price_data("AAPL") is None + + def test_cache_max_size_enforcement(self): + from tradingagents.agents.discovery.quantitative_cache import ( + MAX_CACHE_SIZE, + clear_run_cache, + get_cached_price_data, + set_cached_price_data, + ) + + clear_run_cache() + + for i in range(MAX_CACHE_SIZE + 10): + ticker = f"TICKER{i}" + df = pd.DataFrame({"Close": [float(i)]}) + set_cached_price_data(ticker, df) + + cached_count = 0 + for i in range(MAX_CACHE_SIZE + 10): + ticker = f"TICKER{i}" + if get_cached_price_data(ticker) is not None: + cached_count += 1 + + assert cached_count <= MAX_CACHE_SIZE + + +class TestCacheTTLConstants: + def test_default_ttl_hours_contains_quant_entries(self): + from tradingagents.database.services.market_data import DEFAULT_TTL_HOURS + + assert "quant_indicators" in DEFAULT_TTL_HOURS + assert "volume_analysis" in DEFAULT_TTL_HOURS + assert "relative_strength" in DEFAULT_TTL_HOURS + assert "support_resistance" in DEFAULT_TTL_HOURS + assert "risk_reward" in DEFAULT_TTL_HOURS + + def test_quant_ttl_values(self): + from tradingagents.database.services.market_data import DEFAULT_TTL_HOURS + + assert DEFAULT_TTL_HOURS["quant_indicators"] == 1 + assert DEFAULT_TTL_HOURS["volume_analysis"] == 1 + assert DEFAULT_TTL_HOURS["relative_strength"] == 4 + assert DEFAULT_TTL_HOURS["support_resistance"] == 1 + assert DEFAULT_TTL_HOURS["risk_reward"] == 1 diff --git a/tests/discovery/test_quantitative_config.py b/tests/discovery/test_quantitative_config.py new file mode 100644 index 00000000..7f3cf2aa --- /dev/null +++ b/tests/discovery/test_quantitative_config.py @@ -0,0 +1,124 @@ +import os +from unittest.mock import patch + +import pytest +from pydantic import ValidationError + +from tradingagents.config import ( + QuantitativeWeightsConfig, + TradingAgentsSettings, + get_settings, + reset_settings, +) + + +class TestQuantitativeWeightsConfigDefaults: + def test_default_weight_values_are_set_correctly(self): + config = QuantitativeWeightsConfig() + + assert config.news_sentiment_weight == 0.50 + assert config.quantitative_weight == 0.50 + assert config.momentum_weight == 0.30 + assert config.volume_weight == 0.25 + assert config.relative_strength_weight == 0.25 + assert config.risk_reward_weight == 0.20 + + +class TestQuantitativeWeightsConfigValidation: + def test_top_level_weights_sum_to_one(self): + config = QuantitativeWeightsConfig( + news_sentiment_weight=0.60, + quantitative_weight=0.40, + ) + assert ( + config.news_sentiment_weight + config.quantitative_weight + == pytest.approx(1.0) + ) + + def test_sub_weights_sum_to_one(self): + config = QuantitativeWeightsConfig() + sub_weights_sum = ( + config.momentum_weight + + config.volume_weight + + config.relative_strength_weight + + config.risk_reward_weight + ) + assert sub_weights_sum == pytest.approx(1.0) + + def test_top_level_weights_validation_rejects_invalid_sum(self): + with pytest.raises(ValidationError) as exc_info: + QuantitativeWeightsConfig( + news_sentiment_weight=0.60, + quantitative_weight=0.60, + ) + assert "sum to 1.0" in str(exc_info.value).lower() + + def test_sub_weights_validation_rejects_invalid_sum(self): + with pytest.raises(ValidationError) as exc_info: + QuantitativeWeightsConfig( + momentum_weight=0.50, + volume_weight=0.50, + relative_strength_weight=0.50, + risk_reward_weight=0.50, + ) + assert "sum to 1.0" in str(exc_info.value).lower() + + +class TestQuantitativeWeightsConfigEnvOverride: + def setup_method(self): + reset_settings() + + def teardown_method(self): + reset_settings() + + def test_environment_variable_override_functionality(self): + env_vars = { + "TRADINGAGENTS_QUANTITATIVE_WEIGHTS__NEWS_SENTIMENT_WEIGHT": "0.70", + "TRADINGAGENTS_QUANTITATIVE_WEIGHTS__QUANTITATIVE_WEIGHT": "0.30", + } + with patch.dict(os.environ, env_vars, clear=False): + reset_settings() + settings = get_settings() + + assert settings.quantitative_weights.news_sentiment_weight == pytest.approx( + 0.70 + ) + assert settings.quantitative_weights.quantitative_weight == pytest.approx( + 0.30 + ) + + +class TestQuantitativeSettingsIntegration: + def setup_method(self): + reset_settings() + + def teardown_method(self): + reset_settings() + + def test_quantitative_settings_in_trading_agents_settings(self): + settings = TradingAgentsSettings() + + assert hasattr(settings, "quantitative_weights") + assert isinstance(settings.quantitative_weights, QuantitativeWeightsConfig) + assert hasattr(settings, "quantitative_max_stocks") + assert hasattr(settings, "quantitative_cache_ttl_intraday") + assert hasattr(settings, "quantitative_cache_ttl_relative_strength") + assert hasattr(settings, "min_dollar_volume") + + def test_quantitative_settings_default_values(self): + settings = TradingAgentsSettings() + + assert settings.quantitative_max_stocks == 50 + assert settings.quantitative_cache_ttl_intraday == 1 + assert settings.quantitative_cache_ttl_relative_strength == 4 + assert settings.min_dollar_volume == 1_000_000.0 + + def test_quantitative_max_stocks_bounds(self): + settings = TradingAgentsSettings(quantitative_max_stocks=75) + assert settings.quantitative_max_stocks == 75 + + with pytest.raises(ValidationError): + TradingAgentsSettings(quantitative_max_stocks=5) + + with pytest.raises(ValidationError): + TradingAgentsSettings(quantitative_max_stocks=150) diff --git a/tests/discovery/test_quantitative_gaps.py b/tests/discovery/test_quantitative_gaps.py new file mode 100644 index 00000000..3ae36be6 --- /dev/null +++ b/tests/discovery/test_quantitative_gaps.py @@ -0,0 +1,465 @@ +import threading +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from tradingagents.agents.discovery.models import ( + EventCategory, + NewsArticle, + Sector, + TrendingStock, +) +from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics +from tradingagents.config import QuantitativeWeightsConfig + + +class TestFullPipelineIntegration: + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics" + ) + def test_full_pipeline_news_to_quantitative_enhancement( + self, mock_sr, mock_rs, mock_vol, mock_mom + ): + from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, + ) + + mock_mom.return_value = { + "rsi": 45.0, + "macd": 0.5, + "macd_signal": 0.4, + "macd_histogram": 0.1, + "price_vs_sma50": 5.0, + "price_vs_sma200": 10.0, + "ema10_direction": "up", + "momentum_score": 0.7, + } + mock_vol.return_value = { + "volume_ratio": 1.8, + "volume_trend": "increasing", + "dollar_volume": 25000000.0, + "volume_score": 0.75, + } + mock_rs.return_value = { + "rs_vs_spy_5d": 3.0, + "rs_vs_spy_20d": 5.0, + "rs_vs_spy_60d": 8.0, + "rs_vs_sector": 2.5, + "sector_etf": "XLK", + "relative_strength_score": 0.8, + } + mock_sr.return_value = { + "support_level": 145.0, + "resistance_level": 175.0, + "atr": 3.0, + "suggested_stop": 140.5, + "reward_target": 175.0, + "risk_reward_ratio": 2.5, + "risk_reward_score": 0.85, + } + + articles = [ + NewsArticle( + title="Apple Reports Strong Quarter", + source="Reuters", + url="http://example.com/1", + published_at=datetime.now(), + content_snippet="Apple beats expectations", + ticker_mentions=["AAPL"], + ) + ] + + stocks = [ + TrendingStock( + ticker="AAPL", + company_name="Apple Inc", + score=85.0, + mention_count=15, + sentiment=0.75, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Apple reports strong quarterly results", + source_articles=articles, + ), + TrendingStock( + ticker="MSFT", + company_name="Microsoft Corp", + score=72.0, + mention_count=10, + sentiment=0.6, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.PRODUCT_LAUNCH, + news_summary="Microsoft launches new product", + source_articles=[], + ), + ] + + result = enhance_with_quantitative_scores(stocks, "2024-01-15") + + assert len(result) == 2 + for stock in result: + if stock.quantitative_metrics is not None: + assert stock.conviction_score is not None + assert 0.0 <= stock.conviction_score <= 1.0 + assert stock.quantitative_metrics.rsi == 45.0 + assert stock.quantitative_metrics.sector_etf == "XLK" + + +class TestDelistedStockHandling: + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score" + ) + def test_stock_with_no_trading_data_returns_none(self, mock_mom): + from tradingagents.agents.discovery.quantitative_scorer import ( + calculate_single_stock_metrics, + ) + + mock_mom.side_effect = Exception("No price data available for delisted stock") + + result = calculate_single_stock_metrics("DELIST", "2024-01-15") + + assert result is None + + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics" + ) + def test_enhance_continues_after_delisted_stock_failure( + self, mock_sr, mock_rs, mock_vol, mock_mom + ): + from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, + ) + + def mom_side_effect(ticker, date): + if ticker == "DELIST": + raise Exception("No data for delisted stock") + return {"momentum_score": 0.6, "rsi": 50.0} + + mock_mom.side_effect = mom_side_effect + mock_vol.return_value = {"volume_score": 0.5} + mock_rs.return_value = {"relative_strength_score": 0.5, "sector_etf": "SPY"} + mock_sr.return_value = {"risk_reward_score": 0.5} + + stocks = [ + TrendingStock( + ticker="AAPL", + company_name="Apple Inc", + score=90.0, + mention_count=10, + sentiment=0.5, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.OTHER, + news_summary="Test", + source_articles=[], + ), + TrendingStock( + ticker="DELIST", + company_name="Delisted Corp", + score=80.0, + mention_count=8, + sentiment=0.4, + sector=Sector.OTHER, + event_type=EventCategory.OTHER, + news_summary="Test", + source_articles=[], + ), + ] + + result = enhance_with_quantitative_scores(stocks, "2024-01-15") + + assert len(result) == 2 + + aapl = next((s for s in result if s.ticker == "AAPL"), None) + delist = next((s for s in result if s.ticker == "DELIST"), None) + + assert aapl is not None + assert aapl.quantitative_metrics is not None + + assert delist is not None + assert delist.quantitative_metrics is None + + +class TestWeightConfigurationEdgeCases: + def test_weights_near_boundary_accepted(self): + config = QuantitativeWeightsConfig( + news_sentiment_weight=0.501, + quantitative_weight=0.499, + ) + + total = config.news_sentiment_weight + config.quantitative_weight + assert abs(total - 1.0) < 0.01 + + def test_sub_weights_custom_values_accepted(self): + config = QuantitativeWeightsConfig( + momentum_weight=0.40, + volume_weight=0.20, + relative_strength_weight=0.20, + risk_reward_weight=0.20, + ) + + sub_total = ( + config.momentum_weight + + config.volume_weight + + config.relative_strength_weight + + config.risk_reward_weight + ) + assert abs(sub_total - 1.0) < 0.01 + + +class TestCacheConcurrentAccess: + def test_cache_thread_safety_under_concurrent_writes(self): + from tradingagents.agents.discovery.quantitative_cache import ( + clear_run_cache, + get_cached_price_data, + set_cached_price_data, + ) + + clear_run_cache() + + errors = [] + results = {} + + def write_to_cache(ticker_num): + try: + ticker = f"TICK{ticker_num}" + df = pd.DataFrame({"Close": [float(ticker_num)]}) + set_cached_price_data(ticker, df) + cached = get_cached_price_data(ticker) + if cached is not None: + results[ticker] = cached["Close"].iloc[0] + except Exception as e: + errors.append(str(e)) + + with ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(write_to_cache, i) for i in range(50)] + for f in futures: + f.result() + + assert len(errors) == 0 + + clear_run_cache() + + +class TestConvictionScoreRanking: + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics" + ) + def test_conviction_score_ranking_accuracy(self, mock_calc): + from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, + ) + + def side_effect(ticker, date): + scores = { + "HIGH_BOTH": (0.9, 0.9), + "HIGH_NEWS_LOW_QUANT": (0.3, 0.9), + "LOW_NEWS_HIGH_QUANT": (0.9, 0.3), + "LOW_BOTH": (0.3, 0.3), + } + quant, _ = scores.get(ticker, (0.5, 0.5)) + return QuantitativeMetrics( + momentum_score=quant, + volume_score=quant, + relative_strength_score=quant, + risk_reward_score=quant, + quantitative_score=quant, + ) + + mock_calc.side_effect = side_effect + + stocks = [ + TrendingStock( + ticker="LOW_BOTH", + company_name="Low Both", + score=30.0, + mention_count=5, + sentiment=0.3, + sector=Sector.OTHER, + event_type=EventCategory.OTHER, + news_summary="Test", + source_articles=[], + ), + TrendingStock( + ticker="HIGH_BOTH", + company_name="High Both", + score=90.0, + mention_count=15, + sentiment=0.9, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Test", + source_articles=[], + ), + TrendingStock( + ticker="HIGH_NEWS_LOW_QUANT", + company_name="High News Low Quant", + score=90.0, + mention_count=15, + sentiment=0.9, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.OTHER, + news_summary="Test", + source_articles=[], + ), + TrendingStock( + ticker="LOW_NEWS_HIGH_QUANT", + company_name="Low News High Quant", + score=30.0, + mention_count=5, + sentiment=0.3, + sector=Sector.OTHER, + event_type=EventCategory.OTHER, + news_summary="Test", + source_articles=[], + ), + ] + + result = enhance_with_quantitative_scores(stocks, "2024-01-15") + + assert result[0].ticker == "HIGH_BOTH" + + high_both = next(s for s in result if s.ticker == "HIGH_BOTH") + low_both = next(s for s in result if s.ticker == "LOW_BOTH") + assert high_both.conviction_score > low_both.conviction_score + + +class TestTrendingStockSerializationWithQuantitativeMetrics: + def test_trending_stock_with_quantitative_metrics_roundtrip(self): + articles = [ + NewsArticle( + title="Test Article", + source="Test Source", + url="http://test.com", + published_at=datetime(2024, 1, 15, 10, 30, 0), + content_snippet="Test content", + ticker_mentions=["AAPL"], + ) + ] + + metrics = QuantitativeMetrics( + momentum_score=0.75, + volume_score=0.65, + relative_strength_score=0.80, + risk_reward_score=0.70, + rsi=42.5, + macd=0.35, + macd_signal=0.28, + macd_histogram=0.07, + price_vs_sma50=4.5, + price_vs_sma200=9.2, + ema10_direction="up", + volume_ratio=1.65, + volume_trend="increasing", + dollar_volume=18500000.0, + rs_vs_spy_5d=2.8, + rs_vs_spy_20d=4.2, + rs_vs_spy_60d=7.5, + rs_vs_sector=2.1, + sector_etf="XLK", + support_level=148.50, + resistance_level=165.75, + atr=2.85, + suggested_stop=144.22, + reward_target=165.75, + risk_reward_ratio=2.65, + quantitative_score=0.725, + ) + + original = TrendingStock( + ticker="AAPL", + company_name="Apple Inc", + score=87.5, + mention_count=12, + sentiment=0.72, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Apple reports strong quarterly earnings", + source_articles=articles, + quantitative_metrics=metrics, + conviction_score=0.815, + ) + + data = original.to_dict() + restored = TrendingStock.from_dict(data) + + assert restored.ticker == original.ticker + assert restored.score == original.score + assert restored.conviction_score == original.conviction_score + assert restored.quantitative_metrics is not None + assert restored.quantitative_metrics.momentum_score == 0.75 + assert restored.quantitative_metrics.rsi == 42.5 + assert restored.quantitative_metrics.sector_etf == "XLK" + assert restored.quantitative_metrics.quantitative_score == 0.725 + + +class TestModuleIntegration: + def test_unified_score_combines_all_indicators_correctly(self): + from tradingagents.agents.discovery.quantitative_scorer import ( + calculate_unified_score, + ) + + weights = QuantitativeWeightsConfig() + + score = calculate_unified_score( + momentum=0.8, + volume=0.6, + rs=0.7, + rr=0.9, + weights=weights, + ) + + expected = ( + 0.8 * weights.momentum_weight + + 0.6 * weights.volume_weight + + 0.7 * weights.relative_strength_weight + + 0.9 * weights.risk_reward_weight + ) + assert abs(score - expected) < 0.001 + + def test_unified_score_clamped_to_valid_range(self): + from tradingagents.agents.discovery.quantitative_scorer import ( + calculate_unified_score, + ) + + weights = QuantitativeWeightsConfig() + + score = calculate_unified_score( + momentum=1.5, + volume=1.5, + rs=1.5, + rr=1.5, + weights=weights, + ) + + assert score == 1.0 + + score_low = calculate_unified_score( + momentum=-0.5, + volume=-0.5, + rs=-0.5, + rr=-0.5, + weights=weights, + ) + + assert score_low == 0.0 diff --git a/tests/discovery/test_quantitative_models.py b/tests/discovery/test_quantitative_models.py new file mode 100644 index 00000000..fce01dec --- /dev/null +++ b/tests/discovery/test_quantitative_models.py @@ -0,0 +1,267 @@ +import pytest +from pydantic import ValidationError + +from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics + + +class TestQuantitativeMetricsInstantiation: + def test_model_instantiation_with_valid_data(self): + metrics = QuantitativeMetrics( + momentum_score=0.75, + volume_score=0.60, + relative_strength_score=0.80, + risk_reward_score=0.70, + rsi=45.5, + macd=0.25, + macd_signal=0.20, + macd_histogram=0.05, + price_vs_sma50=5.2, + price_vs_sma200=12.3, + ema10_direction="up", + volume_ratio=1.8, + volume_trend="increasing", + dollar_volume=15_000_000.0, + rs_vs_spy_5d=2.5, + rs_vs_spy_20d=5.0, + rs_vs_spy_60d=8.2, + rs_vs_sector=3.1, + sector_etf="XLK", + support_level=145.50, + resistance_level=162.30, + atr=3.25, + suggested_stop=140.62, + reward_target=162.30, + risk_reward_ratio=2.8, + quantitative_score=0.72, + ) + + assert metrics.momentum_score == 0.75 + assert metrics.volume_score == 0.60 + assert metrics.relative_strength_score == 0.80 + assert metrics.risk_reward_score == 0.70 + assert metrics.rsi == 45.5 + assert metrics.macd == 0.25 + assert metrics.ema10_direction == "up" + assert metrics.sector_etf == "XLK" + assert metrics.quantitative_score == 0.72 + + +class TestQuantitativeMetricsScoreValidation: + def test_score_validation_accepts_valid_range(self): + metrics = QuantitativeMetrics( + momentum_score=0.0, + volume_score=0.5, + relative_strength_score=1.0, + risk_reward_score=0.99, + rsi=50.0, + macd=0.0, + macd_signal=0.0, + macd_histogram=0.0, + price_vs_sma50=0.0, + price_vs_sma200=0.0, + ema10_direction="flat", + volume_ratio=1.0, + volume_trend="flat", + dollar_volume=1_000_000.0, + rs_vs_spy_5d=0.0, + rs_vs_spy_20d=0.0, + rs_vs_spy_60d=0.0, + rs_vs_sector=0.0, + sector_etf="SPY", + support_level=100.0, + resistance_level=110.0, + atr=2.0, + suggested_stop=97.0, + reward_target=110.0, + risk_reward_ratio=3.33, + quantitative_score=0.5, + ) + + assert metrics.momentum_score == 0.0 + assert metrics.relative_strength_score == 1.0 + assert metrics.quantitative_score == 0.5 + + def test_score_validation_rejects_negative_score(self): + with pytest.raises(ValidationError) as exc_info: + QuantitativeMetrics( + momentum_score=-0.1, + volume_score=0.5, + relative_strength_score=0.5, + risk_reward_score=0.5, + rsi=50.0, + macd=0.0, + macd_signal=0.0, + macd_histogram=0.0, + price_vs_sma50=0.0, + price_vs_sma200=0.0, + ema10_direction="flat", + volume_ratio=1.0, + volume_trend="flat", + dollar_volume=1_000_000.0, + rs_vs_spy_5d=0.0, + rs_vs_spy_20d=0.0, + rs_vs_spy_60d=0.0, + rs_vs_sector=0.0, + sector_etf="SPY", + support_level=100.0, + resistance_level=110.0, + atr=2.0, + suggested_stop=97.0, + reward_target=110.0, + risk_reward_ratio=3.33, + quantitative_score=0.5, + ) + assert "momentum_score" in str(exc_info.value) + + def test_score_validation_rejects_above_one(self): + with pytest.raises(ValidationError) as exc_info: + QuantitativeMetrics( + momentum_score=0.5, + volume_score=1.5, + relative_strength_score=0.5, + risk_reward_score=0.5, + rsi=50.0, + macd=0.0, + macd_signal=0.0, + macd_histogram=0.0, + price_vs_sma50=0.0, + price_vs_sma200=0.0, + ema10_direction="flat", + volume_ratio=1.0, + volume_trend="flat", + dollar_volume=1_000_000.0, + rs_vs_spy_5d=0.0, + rs_vs_spy_20d=0.0, + rs_vs_spy_60d=0.0, + rs_vs_sector=0.0, + sector_etf="SPY", + support_level=100.0, + resistance_level=110.0, + atr=2.0, + suggested_stop=97.0, + reward_target=110.0, + risk_reward_ratio=3.33, + quantitative_score=0.5, + ) + assert "volume_score" in str(exc_info.value) + + +class TestQuantitativeMetricsSerialization: + def test_to_dict_and_from_dict_roundtrip(self): + original = QuantitativeMetrics( + momentum_score=0.75, + volume_score=0.60, + relative_strength_score=0.80, + risk_reward_score=0.70, + rsi=45.5, + macd=0.25, + macd_signal=0.20, + macd_histogram=0.05, + price_vs_sma50=5.2, + price_vs_sma200=12.3, + ema10_direction="up", + volume_ratio=1.8, + volume_trend="increasing", + dollar_volume=15_000_000.0, + rs_vs_spy_5d=2.5, + rs_vs_spy_20d=5.0, + rs_vs_spy_60d=8.2, + rs_vs_sector=3.1, + sector_etf="XLK", + support_level=145.50, + resistance_level=162.30, + atr=3.25, + suggested_stop=140.62, + reward_target=162.30, + risk_reward_ratio=2.8, + quantitative_score=0.72, + ) + + data = original.to_dict() + restored = QuantitativeMetrics.from_dict(data) + + assert restored.momentum_score == original.momentum_score + assert restored.volume_score == original.volume_score + assert restored.relative_strength_score == original.relative_strength_score + assert restored.risk_reward_score == original.risk_reward_score + assert restored.rsi == original.rsi + assert restored.macd == original.macd + assert restored.ema10_direction == original.ema10_direction + assert restored.sector_etf == original.sector_etf + assert restored.quantitative_score == original.quantitative_score + + +class TestQuantitativeMetricsOptionalFields: + def test_optional_field_handling_with_none_defaults(self): + metrics = QuantitativeMetrics( + momentum_score=0.5, + volume_score=0.5, + relative_strength_score=0.5, + risk_reward_score=0.5, + rsi=None, + macd=None, + macd_signal=None, + macd_histogram=None, + price_vs_sma50=None, + price_vs_sma200=None, + ema10_direction=None, + volume_ratio=None, + volume_trend=None, + dollar_volume=None, + rs_vs_spy_5d=None, + rs_vs_spy_20d=None, + rs_vs_spy_60d=None, + rs_vs_sector=None, + sector_etf=None, + support_level=None, + resistance_level=None, + atr=None, + suggested_stop=None, + reward_target=None, + risk_reward_ratio=None, + quantitative_score=0.5, + ) + + assert metrics.rsi is None + assert metrics.macd is None + assert metrics.ema10_direction is None + assert metrics.sector_etf is None + assert metrics.momentum_score == 0.5 + assert metrics.quantitative_score == 0.5 + + def test_serialization_with_none_values(self): + original = QuantitativeMetrics( + momentum_score=0.5, + volume_score=0.5, + relative_strength_score=0.5, + risk_reward_score=0.5, + rsi=None, + macd=None, + macd_signal=None, + macd_histogram=None, + price_vs_sma50=None, + price_vs_sma200=None, + ema10_direction=None, + volume_ratio=None, + volume_trend=None, + dollar_volume=None, + rs_vs_spy_5d=None, + rs_vs_spy_20d=None, + rs_vs_spy_60d=None, + rs_vs_sector=None, + sector_etf=None, + support_level=None, + resistance_level=None, + atr=None, + suggested_stop=None, + reward_target=None, + risk_reward_ratio=None, + quantitative_score=0.5, + ) + + data = original.to_dict() + restored = QuantitativeMetrics.from_dict(data) + + assert restored.rsi is None + assert restored.ema10_direction is None + assert restored.momentum_score == 0.5 diff --git a/tests/discovery/test_quantitative_scorer.py b/tests/discovery/test_quantitative_scorer.py new file mode 100644 index 00000000..e44e08fd --- /dev/null +++ b/tests/discovery/test_quantitative_scorer.py @@ -0,0 +1,312 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + + +class TestCalculateSingleStockMetrics: + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_volume_metrics" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_relative_strength_metrics" + ) + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_support_resistance_metrics" + ) + def test_calculate_single_stock_metrics_success( + self, mock_sr, mock_rs, mock_vol, mock_mom + ): + from tradingagents.agents.discovery.quantitative_scorer import ( + calculate_single_stock_metrics, + ) + + mock_mom.return_value = { + "rsi": 45.0, + "macd": 0.5, + "macd_signal": 0.4, + "macd_histogram": 0.1, + "price_vs_sma50": 2.5, + "price_vs_sma200": 5.0, + "ema10_direction": "up", + "momentum_score": 0.65, + } + mock_vol.return_value = { + "volume_ratio": 1.5, + "volume_trend": "increasing", + "dollar_volume": 5000000.0, + "volume_score": 0.7, + } + mock_rs.return_value = { + "rs_vs_spy_5d": 2.0, + "rs_vs_spy_20d": 3.5, + "rs_vs_spy_60d": 5.0, + "rs_vs_sector": 2.5, + "sector_etf": "XLK", + "relative_strength_score": 0.6, + } + mock_sr.return_value = { + "support_level": 150.0, + "resistance_level": 180.0, + "atr": 3.5, + "suggested_stop": 145.0, + "reward_target": 180.0, + "risk_reward_ratio": 2.5, + "risk_reward_score": 0.75, + } + + result = calculate_single_stock_metrics("AAPL", "2024-01-15") + + assert result is not None + assert result.momentum_score == 0.65 + assert result.volume_score == 0.7 + assert result.relative_strength_score == 0.6 + assert result.risk_reward_score == 0.75 + assert 0.0 <= result.quantitative_score <= 1.0 + + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_momentum_score" + ) + def test_calculate_single_stock_metrics_failure(self, mock_mom): + from tradingagents.agents.discovery.quantitative_scorer import ( + calculate_single_stock_metrics, + ) + + mock_mom.side_effect = Exception("API failure") + + result = calculate_single_stock_metrics("INVALID", "2024-01-15") + + assert result is None + + +class TestCalculateUnifiedScore: + def test_unified_score_basic(self): + from tradingagents.agents.discovery.quantitative_scorer import ( + calculate_unified_score, + ) + from tradingagents.config import QuantitativeWeightsConfig + + weights = QuantitativeWeightsConfig() + + score = calculate_unified_score( + momentum=0.8, + volume=0.6, + rs=0.7, + rr=0.5, + weights=weights, + ) + + expected = 0.8 * 0.30 + 0.6 * 0.25 + 0.7 * 0.25 + 0.5 * 0.20 + assert abs(score - expected) < 0.001 + + def test_unified_score_all_ones(self): + from tradingagents.agents.discovery.quantitative_scorer import ( + calculate_unified_score, + ) + from tradingagents.config import QuantitativeWeightsConfig + + weights = QuantitativeWeightsConfig() + + score = calculate_unified_score( + momentum=1.0, + volume=1.0, + rs=1.0, + rr=1.0, + weights=weights, + ) + + assert abs(score - 1.0) < 0.001 + + def test_unified_score_all_zeros(self): + from tradingagents.agents.discovery.quantitative_scorer import ( + calculate_unified_score, + ) + from tradingagents.config import QuantitativeWeightsConfig + + weights = QuantitativeWeightsConfig() + + score = calculate_unified_score( + momentum=0.0, + volume=0.0, + rs=0.0, + rr=0.0, + weights=weights, + ) + + assert abs(score - 0.0) < 0.001 + + +class TestEnhanceWithQuantitativeScores: + def _create_mock_trending_stock(self, ticker: str, score: float): + from tradingagents.agents.discovery.models import ( + EventCategory, + NewsArticle, + Sector, + TrendingStock, + ) + + return TrendingStock( + ticker=ticker, + company_name=f"{ticker} Inc", + score=score, + mention_count=10, + sentiment=0.5, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Test summary", + source_articles=[], + ) + + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics" + ) + def test_enhance_with_quantitative_scores_basic(self, mock_calc): + from tradingagents.agents.discovery.quantitative_models import ( + QuantitativeMetrics, + ) + from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, + ) + + mock_calc.return_value = QuantitativeMetrics( + momentum_score=0.7, + volume_score=0.6, + relative_strength_score=0.65, + risk_reward_score=0.7, + quantitative_score=0.66, + ) + + stocks = [ + self._create_mock_trending_stock("AAPL", 90.0), + self._create_mock_trending_stock("MSFT", 80.0), + ] + + result = enhance_with_quantitative_scores(stocks, "2024-01-15") + + assert len(result) == 2 + for stock in result: + assert stock.quantitative_metrics is not None + assert stock.conviction_score is not None + + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics" + ) + def test_enhance_caps_at_max_stocks(self, mock_calc): + from tradingagents.agents.discovery.quantitative_models import ( + QuantitativeMetrics, + ) + from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, + ) + + mock_calc.return_value = QuantitativeMetrics( + momentum_score=0.5, + volume_score=0.5, + relative_strength_score=0.5, + risk_reward_score=0.5, + quantitative_score=0.5, + ) + + stocks = [ + self._create_mock_trending_stock(f"TICK{i}", 100.0 - i) for i in range(60) + ] + + result = enhance_with_quantitative_scores(stocks, "2024-01-15", max_stocks=10) + + assert mock_calc.call_count == 10 + + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics" + ) + def test_enhance_handles_partial_failures(self, mock_calc): + from tradingagents.agents.discovery.quantitative_models import ( + QuantitativeMetrics, + ) + from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, + ) + + def side_effect(ticker, date): + if ticker == "FAIL": + return None + return QuantitativeMetrics( + momentum_score=0.5, + volume_score=0.5, + relative_strength_score=0.5, + risk_reward_score=0.5, + quantitative_score=0.5, + ) + + mock_calc.side_effect = side_effect + + stocks = [ + self._create_mock_trending_stock("AAPL", 90.0), + self._create_mock_trending_stock("FAIL", 85.0), + self._create_mock_trending_stock("MSFT", 80.0), + ] + + result = enhance_with_quantitative_scores(stocks, "2024-01-15") + + assert len(result) == 3 + + successful = [s for s in result if s.quantitative_metrics is not None] + assert len(successful) == 2 + + @patch( + "tradingagents.agents.discovery.quantitative_scorer.calculate_single_stock_metrics" + ) + def test_enhance_sorts_by_conviction_score(self, mock_calc): + from tradingagents.agents.discovery.quantitative_models import ( + QuantitativeMetrics, + ) + from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, + ) + + def side_effect(ticker, date): + scores = { + "LOW": 0.3, + "MED": 0.5, + "HIGH": 0.9, + } + quant_score = scores.get(ticker, 0.5) + return QuantitativeMetrics( + momentum_score=quant_score, + volume_score=quant_score, + relative_strength_score=quant_score, + risk_reward_score=quant_score, + quantitative_score=quant_score, + ) + + mock_calc.side_effect = side_effect + + stocks = [ + self._create_mock_trending_stock("LOW", 60.0), + self._create_mock_trending_stock("HIGH", 40.0), + self._create_mock_trending_stock("MED", 50.0), + ] + + result = enhance_with_quantitative_scores(stocks, "2024-01-15") + + assert ( + result[0].quantitative_metrics.quantitative_score + >= result[1].quantitative_metrics.quantitative_score + ) + assert ( + result[1].quantitative_metrics.quantitative_score + >= result[2].quantitative_metrics.quantitative_score + ) + + +class TestErrorHandling: + def test_empty_stock_list(self): + from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, + ) + + result = enhance_with_quantitative_scores([], "2024-01-15") + + assert result == [] diff --git a/tests/discovery/test_relative_strength.py b/tests/discovery/test_relative_strength.py new file mode 100644 index 00000000..a4331961 --- /dev/null +++ b/tests/discovery/test_relative_strength.py @@ -0,0 +1,190 @@ +from unittest.mock import MagicMock, patch + +import pytest + + +class TestCalculateReturn: + def test_calculate_positive_return(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_return, + ) + + prices = [100.0, 102.0, 105.0, 108.0, 110.0] + ret = calculate_return(prices, days=5) + assert ret == pytest.approx(10.0, rel=0.01) + + def test_calculate_negative_return(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_return, + ) + + prices = [100.0, 98.0, 95.0, 92.0, 90.0] + ret = calculate_return(prices, days=5) + assert ret == pytest.approx(-10.0, rel=0.01) + + def test_calculate_return_insufficient_data(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_return, + ) + + prices = [100.0, 110.0] + ret = calculate_return(prices, days=5) + assert ret == pytest.approx(10.0, rel=0.01) + + def test_calculate_return_empty_list(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_return, + ) + + prices = [] + ret = calculate_return(prices, days=5) + assert ret == 0.0 + + +class TestCalculateRelativeStrength: + def test_positive_outperformance(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_relative_strength, + ) + + rs = calculate_relative_strength(stock_return=15.0, benchmark_return=10.0) + assert rs == 5.0 + + def test_negative_outperformance(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_relative_strength, + ) + + rs = calculate_relative_strength(stock_return=5.0, benchmark_return=10.0) + assert rs == -5.0 + + def test_equal_returns(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_relative_strength, + ) + + rs = calculate_relative_strength(stock_return=10.0, benchmark_return=10.0) + assert rs == 0.0 + + +class TestSectorEtfMap: + def test_sector_etf_map_contains_expected_sectors(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + SECTOR_ETF_MAP, + ) + + assert "technology" in SECTOR_ETF_MAP + assert "finance" in SECTOR_ETF_MAP + assert "healthcare" in SECTOR_ETF_MAP + assert "energy" in SECTOR_ETF_MAP + assert "consumer_goods" in SECTOR_ETF_MAP + assert "industrials" in SECTOR_ETF_MAP + assert "other" in SECTOR_ETF_MAP + + def test_sector_etf_map_values(self): + from tradingagents.agents.discovery.indicators.relative_strength import ( + SECTOR_ETF_MAP, + ) + + assert SECTOR_ETF_MAP["technology"] == "XLK" + assert SECTOR_ETF_MAP["finance"] == "XLF" + assert SECTOR_ETF_MAP["healthcare"] == "XLV" + assert SECTOR_ETF_MAP["energy"] == "XLE" + assert SECTOR_ETF_MAP["other"] == "SPY" + + +class TestGetSectorEtf: + @patch( + "tradingagents.agents.discovery.indicators.relative_strength.classify_sector" + ) + def test_get_sector_etf_for_tech_stock(self, mock_classify): + from tradingagents.agents.discovery.indicators.relative_strength import ( + get_sector_etf, + ) + + mock_classify.return_value = "technology" + etf = get_sector_etf("AAPL") + assert etf == "XLK" + + @patch( + "tradingagents.agents.discovery.indicators.relative_strength.classify_sector" + ) + def test_get_sector_etf_for_finance_stock(self, mock_classify): + from tradingagents.agents.discovery.indicators.relative_strength import ( + get_sector_etf, + ) + + mock_classify.return_value = "finance" + etf = get_sector_etf("JPM") + assert etf == "XLF" + + @patch( + "tradingagents.agents.discovery.indicators.relative_strength.classify_sector" + ) + def test_get_sector_etf_for_unknown_sector(self, mock_classify): + from tradingagents.agents.discovery.indicators.relative_strength import ( + get_sector_etf, + ) + + mock_classify.return_value = "other" + etf = get_sector_etf("XYZ") + assert etf == "SPY" + + +class TestCalculateRelativeStrengthMetrics: + @patch( + "tradingagents.agents.discovery.indicators.relative_strength._get_price_history" + ) + @patch("tradingagents.agents.discovery.indicators.relative_strength.get_sector_etf") + def test_calculate_rs_metrics_returns_dict( + self, mock_sector_etf, mock_price_history + ): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_relative_strength_metrics, + ) + + mock_sector_etf.return_value = "XLK" + + base_prices = [100.0 + i * 0.5 for i in range(70)] + spy_prices = [100.0 + i * 0.3 for i in range(70)] + sector_prices = [100.0 + i * 0.4 for i in range(70)] + + def price_history_side_effect(ticker, date, days): + if ticker == "AAPL": + return base_prices[-days:] + elif ticker == "SPY": + return spy_prices[-days:] + else: + return sector_prices[-days:] + + mock_price_history.side_effect = price_history_side_effect + + result = calculate_relative_strength_metrics("AAPL", "2024-01-15") + + assert isinstance(result, dict) + assert "rs_vs_spy_5d" in result + assert "rs_vs_spy_20d" in result + assert "rs_vs_spy_60d" in result + assert "rs_vs_sector" in result + assert "sector_etf" in result + assert "relative_strength_score" in result + assert 0.0 <= result["relative_strength_score"] <= 1.0 + + @patch( + "tradingagents.agents.discovery.indicators.relative_strength._get_price_history" + ) + @patch("tradingagents.agents.discovery.indicators.relative_strength.get_sector_etf") + def test_calculate_rs_metrics_handles_missing_benchmark_data( + self, mock_sector_etf, mock_price_history + ): + from tradingagents.agents.discovery.indicators.relative_strength import ( + calculate_relative_strength_metrics, + ) + + mock_sector_etf.return_value = "XLK" + mock_price_history.return_value = [] + + result = calculate_relative_strength_metrics("INVALID", "2024-01-15") + + assert isinstance(result, dict) + assert result["relative_strength_score"] == 0.5 diff --git a/tests/discovery/test_risk_reward.py b/tests/discovery/test_risk_reward.py new file mode 100644 index 00000000..70eb3d90 --- /dev/null +++ b/tests/discovery/test_risk_reward.py @@ -0,0 +1,121 @@ +import pytest + + +class TestCalculateStopLoss: + def test_calculate_stop_loss_default_multiplier(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_stop_loss, + ) + + stop = calculate_stop_loss(price=100.0, atr=2.0, multiplier=1.5) + assert stop == 97.0 + + def test_calculate_stop_loss_custom_multiplier(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_stop_loss, + ) + + stop = calculate_stop_loss(price=100.0, atr=2.0, multiplier=2.0) + assert stop == 96.0 + + def test_calculate_stop_loss_large_atr(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_stop_loss, + ) + + stop = calculate_stop_loss(price=100.0, atr=10.0, multiplier=1.5) + assert stop == 85.0 + + +class TestCalculateRewardTarget: + def test_calculate_reward_target(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_reward_target, + ) + + target = calculate_reward_target(price=100.0, resistance=120.0) + assert target == 120.0 + + def test_calculate_reward_target_resistance_below_price(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_reward_target, + ) + + target = calculate_reward_target(price=100.0, resistance=90.0) + assert target == 90.0 + + +class TestCalculateRiskRewardRatio: + def test_calculate_rr_ratio_good_trade(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_ratio, + ) + + rr = calculate_risk_reward_ratio(price=100.0, stop=95.0, target=115.0) + assert rr == 3.0 + + def test_calculate_rr_ratio_poor_trade(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_ratio, + ) + + rr = calculate_risk_reward_ratio(price=100.0, stop=95.0, target=102.0) + assert rr == pytest.approx(0.4, rel=0.01) + + def test_calculate_rr_ratio_stop_at_price_returns_zero(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_ratio, + ) + + rr = calculate_risk_reward_ratio(price=100.0, stop=100.0, target=110.0) + assert rr == 0.0 + + def test_calculate_rr_ratio_target_below_price(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_ratio, + ) + + rr = calculate_risk_reward_ratio(price=100.0, stop=95.0, target=98.0) + assert rr < 0 + + +class TestCalculateRiskRewardScore: + def test_excellent_rr_ratio_high_score(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_score, + ) + + score = calculate_risk_reward_score(rr_ratio=3.5) + assert 0.9 <= score <= 1.0 + + def test_good_rr_ratio_moderate_high_score(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_score, + ) + + score = calculate_risk_reward_score(rr_ratio=2.5) + assert 0.7 <= score <= 0.9 + + def test_acceptable_rr_ratio_moderate_score(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_score, + ) + + score = calculate_risk_reward_score(rr_ratio=1.5) + assert 0.4 <= score <= 0.7 + + def test_poor_rr_ratio_low_score(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_score, + ) + + score = calculate_risk_reward_score(rr_ratio=0.5) + assert 0.0 <= score <= 0.4 + + def test_negative_rr_ratio_very_low_score(self): + from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_risk_reward_score, + ) + + score = calculate_risk_reward_score(rr_ratio=-1.0) + assert score == 0.0 diff --git a/tests/discovery/test_support_resistance.py b/tests/discovery/test_support_resistance.py new file mode 100644 index 00000000..dd0fcebf --- /dev/null +++ b/tests/discovery/test_support_resistance.py @@ -0,0 +1,224 @@ +from unittest.mock import MagicMock, patch + +import pytest + + +class TestFindSupportLevels: + def test_find_support_from_lows(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + find_support_levels, + ) + + lows = ( + [100.0, 98.0, 97.0, 99.0, 96.0, 95.0, 97.0, 98.0, 94.0, 93.0] + + [95.0, 96.0, 94.0, 93.0, 92.0, 91.0, 90.0, 89.0, 88.0, 87.0] + + [88.0, 89.0, 87.0, 86.0, 85.0] * 6 + ) + + support_20d, support_50d = find_support_levels( + lows, lookback_20=20, lookback_50=50 + ) + + assert support_20d <= 100.0 + assert support_50d <= 100.0 + + def test_find_support_empty_list(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + find_support_levels, + ) + + lows = [] + support_20d, support_50d = find_support_levels( + lows, lookback_20=20, lookback_50=50 + ) + + assert support_20d == 0.0 + assert support_50d == 0.0 + + +class TestFindResistanceLevels: + def test_find_resistance_from_highs(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + find_resistance_levels, + ) + + highs = ( + [100.0, 102.0, 105.0, 103.0, 108.0, 106.0, 110.0, 107.0, 112.0, 109.0] + + [111.0, 113.0, 115.0, 114.0, 117.0, 116.0, 120.0, 118.0, 122.0, 119.0] + + [120.0, 121.0, 123.0, 124.0, 125.0] * 6 + ) + + resistance_20d, resistance_50d = find_resistance_levels( + highs, lookback_20=20, lookback_50=50 + ) + + assert resistance_20d >= 100.0 + assert resistance_50d >= 100.0 + + def test_find_resistance_empty_list(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + find_resistance_levels, + ) + + highs = [] + resistance_20d, resistance_50d = find_resistance_levels( + highs, lookback_20=20, lookback_50=50 + ) + + assert resistance_20d == 0.0 + assert resistance_50d == 0.0 + + +class TestDetectSwingPoints: + def test_detect_swing_highs_and_lows(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + detect_swing_points, + ) + + prices = [ + 100, + 102, + 105, + 103, + 98, + 95, + 97, + 100, + 103, + 107, + 105, + 102, + 99, + 96, + 94, + 97, + 101, + ] + swing_lows, swing_highs = detect_swing_points(prices, n_bars=3) + + assert len(swing_lows) >= 0 + assert len(swing_highs) >= 0 + + def test_detect_swing_points_short_list(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + detect_swing_points, + ) + + prices = [100, 105, 102] + swing_lows, swing_highs = detect_swing_points(prices, n_bars=3) + + assert isinstance(swing_lows, list) + assert isinstance(swing_highs, list) + + def test_detect_swing_points_empty_list(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + detect_swing_points, + ) + + prices = [] + swing_lows, swing_highs = detect_swing_points(prices, n_bars=5) + + assert swing_lows == [] + assert swing_highs == [] + + +class TestGetNearestLevels: + def test_get_nearest_support_and_resistance(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + get_nearest_levels, + ) + + price = 150.0 + supports = [140.0, 145.0, 130.0, 120.0] + resistances = [155.0, 160.0, 170.0, 180.0] + + nearest_support, nearest_resistance = get_nearest_levels( + price, supports, resistances + ) + + assert nearest_support == 145.0 + assert nearest_resistance == 155.0 + + def test_get_nearest_levels_no_support_below(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + get_nearest_levels, + ) + + price = 100.0 + supports = [110.0, 120.0, 130.0] + resistances = [150.0, 160.0] + + nearest_support, nearest_resistance = get_nearest_levels( + price, supports, resistances + ) + + assert nearest_support == 0.0 + assert nearest_resistance == 150.0 + + def test_get_nearest_levels_empty_lists(self): + from tradingagents.agents.discovery.indicators.support_resistance import ( + get_nearest_levels, + ) + + price = 150.0 + supports = [] + resistances = [] + + nearest_support, nearest_resistance = get_nearest_levels( + price, supports, resistances + ) + + assert nearest_support == 0.0 + assert nearest_resistance == 0.0 + + +class TestCalculateSupportResistanceMetrics: + @patch( + "tradingagents.agents.discovery.indicators.support_resistance._get_ohlc_data" + ) + @patch("tradingagents.agents.discovery.indicators.support_resistance._get_atr") + def test_calculate_sr_metrics_returns_dict(self, mock_atr, mock_ohlc): + from tradingagents.agents.discovery.indicators.support_resistance import ( + calculate_support_resistance_metrics, + ) + + mock_ohlc.return_value = { + "highs": [105.0 + i * 0.5 for i in range(60)], + "lows": [95.0 + i * 0.3 for i in range(60)], + "closes": [100.0 + i * 0.4 for i in range(60)], + "current_price": 125.0, + } + mock_atr.return_value = 2.5 + + result = calculate_support_resistance_metrics("AAPL", "2024-01-15") + + assert isinstance(result, dict) + assert "support_level" in result + assert "resistance_level" in result + assert "atr" in result + assert "suggested_stop" in result + assert "reward_target" in result + assert "risk_reward_ratio" in result + assert "risk_reward_score" in result + + @patch( + "tradingagents.agents.discovery.indicators.support_resistance._get_ohlc_data" + ) + @patch("tradingagents.agents.discovery.indicators.support_resistance._get_atr") + def test_calculate_sr_metrics_handles_missing_data(self, mock_atr, mock_ohlc): + from tradingagents.agents.discovery.indicators.support_resistance import ( + calculate_support_resistance_metrics, + ) + + mock_ohlc.return_value = { + "highs": [], + "lows": [], + "closes": [], + "current_price": None, + } + mock_atr.return_value = None + + result = calculate_support_resistance_metrics("INVALID", "2024-01-15") + + assert isinstance(result, dict) + assert result["risk_reward_score"] == 0.5 diff --git a/tests/discovery/test_volume.py b/tests/discovery/test_volume.py new file mode 100644 index 00000000..58728366 --- /dev/null +++ b/tests/discovery/test_volume.py @@ -0,0 +1,201 @@ +from unittest.mock import MagicMock, patch + +import pytest + + +class TestCalculateVolumeRatio: + def test_volume_ratio_above_average(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_ratio, + ) + + ratio = calculate_volume_ratio( + current_volume=2_000_000, avg_volume_20d=1_000_000 + ) + assert ratio == 2.0 + + def test_volume_ratio_below_average(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_ratio, + ) + + ratio = calculate_volume_ratio(current_volume=500_000, avg_volume_20d=1_000_000) + assert ratio == 0.5 + + def test_volume_ratio_zero_average_returns_zero(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_ratio, + ) + + ratio = calculate_volume_ratio(current_volume=1_000_000, avg_volume_20d=0) + assert ratio == 0.0 + + +class TestCalculateVolumeTrend: + def test_volume_trend_increasing(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_trend, + ) + + volume_series = [100_000, 150_000, 200_000, 250_000, 300_000] + trend, slope = calculate_volume_trend(volume_series) + assert trend == "increasing" + assert slope > 0 + + def test_volume_trend_decreasing(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_trend, + ) + + volume_series = [300_000, 250_000, 200_000, 150_000, 100_000] + trend, slope = calculate_volume_trend(volume_series) + assert trend == "decreasing" + assert slope < 0 + + def test_volume_trend_flat(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_trend, + ) + + volume_series = [100_000, 100_100, 99_900, 100_050, 99_950] + trend, slope = calculate_volume_trend(volume_series) + assert trend == "flat" + + def test_volume_trend_empty_returns_flat(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_trend, + ) + + volume_series = [] + trend, slope = calculate_volume_trend(volume_series) + assert trend == "flat" + assert slope == 0.0 + + +class TestCalculateDollarVolume: + def test_dollar_volume_calculation(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_dollar_volume, + ) + + dollar_vol = calculate_dollar_volume(price=150.0, volume=1_000_000) + assert dollar_vol == 150_000_000.0 + + def test_dollar_volume_zero_price(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_dollar_volume, + ) + + dollar_vol = calculate_dollar_volume(price=0.0, volume=1_000_000) + assert dollar_vol == 0.0 + + +class TestCalculateVolumeScore: + def test_volume_spike_with_positive_price_high_score(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_score, + ) + + score = calculate_volume_score( + volume_ratio=2.5, + trend="increasing", + dollar_volume=50_000_000.0, + price_change=5.0, + min_dollar_volume=1_000_000.0, + ) + assert 0.8 <= score <= 1.0 + + def test_above_average_volume_moderate_score(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_score, + ) + + score = calculate_volume_score( + volume_ratio=1.5, + trend="increasing", + dollar_volume=10_000_000.0, + price_change=2.0, + min_dollar_volume=1_000_000.0, + ) + assert 0.7 <= score <= 0.9 + + def test_normal_volume_neutral_score(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_score, + ) + + score = calculate_volume_score( + volume_ratio=1.0, + trend="flat", + dollar_volume=5_000_000.0, + price_change=0.5, + min_dollar_volume=1_000_000.0, + ) + assert 0.3 <= score <= 0.6 + + def test_low_dollar_volume_penalized(self): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_score, + ) + + high_volume_score = calculate_volume_score( + volume_ratio=2.0, + trend="increasing", + dollar_volume=50_000_000.0, + price_change=3.0, + min_dollar_volume=1_000_000.0, + ) + + low_volume_score = calculate_volume_score( + volume_ratio=2.0, + trend="increasing", + dollar_volume=500_000.0, + price_change=3.0, + min_dollar_volume=1_000_000.0, + ) + + assert low_volume_score < high_volume_score + + +class TestCalculateVolumeMetrics: + @patch("tradingagents.agents.discovery.indicators.volume._get_volume_price_data") + def test_calculate_volume_metrics_returns_dict(self, mock_get_data): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_metrics, + ) + + mock_get_data.return_value = { + "volumes": [1_000_000] * 25, + "prices": [100.0] * 25, + "current_price": 105.0, + "current_volume": 1_500_000, + "price_change_pct": 5.0, + } + + result = calculate_volume_metrics("AAPL", "2024-01-15") + + assert isinstance(result, dict) + assert "volume_ratio" in result + assert "volume_trend" in result + assert "dollar_volume" in result + assert "volume_score" in result + assert 0.0 <= result["volume_score"] <= 1.0 + + @patch("tradingagents.agents.discovery.indicators.volume._get_volume_price_data") + def test_calculate_volume_metrics_handles_missing_data(self, mock_get_data): + from tradingagents.agents.discovery.indicators.volume import ( + calculate_volume_metrics, + ) + + mock_get_data.return_value = { + "volumes": [], + "prices": [], + "current_price": None, + "current_volume": None, + "price_change_pct": 0.0, + } + + result = calculate_volume_metrics("INVALID", "2024-01-15") + + assert isinstance(result, dict) + assert result["volume_score"] == 0.5 diff --git a/tradingagents/agents/discovery/__init__.py b/tradingagents/agents/discovery/__init__.py index 30cfc436..2a35d118 100644 --- a/tradingagents/agents/discovery/__init__.py +++ b/tradingagents/agents/discovery/__init__.py @@ -22,6 +22,7 @@ from .persistence import ( generate_markdown_summary, save_discovery_result, ) +from .quantitative_models import QuantitativeMetrics from .scorer import ( DEFAULT_DECAY_RATE, DEFAULT_MAX_RESULTS, @@ -50,4 +51,5 @@ __all__ = [ "DEFAULT_MIN_MENTIONS", "save_discovery_result", "generate_markdown_summary", + "QuantitativeMetrics", ] diff --git a/tradingagents/agents/discovery/indicators/__init__.py b/tradingagents/agents/discovery/indicators/__init__.py new file mode 100644 index 00000000..8aac671c --- /dev/null +++ b/tradingagents/agents/discovery/indicators/__init__.py @@ -0,0 +1,61 @@ +from tradingagents.agents.discovery.indicators.momentum import ( + calculate_ema_direction, + calculate_macd_score, + calculate_momentum_score, + calculate_rsi_score, + calculate_sma_score, +) +from tradingagents.agents.discovery.indicators.relative_strength import ( + SECTOR_ETF_MAP, + calculate_relative_strength, + calculate_relative_strength_metrics, + calculate_return, + get_sector_etf, +) +from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_reward_target, + calculate_risk_reward_ratio, + calculate_risk_reward_score, + calculate_stop_loss, +) +from tradingagents.agents.discovery.indicators.support_resistance import ( + calculate_support_resistance_metrics, + detect_swing_points, + find_resistance_levels, + find_support_levels, + get_nearest_levels, +) +from tradingagents.agents.discovery.indicators.volume import ( + calculate_dollar_volume, + calculate_volume_metrics, + calculate_volume_ratio, + calculate_volume_score, + calculate_volume_trend, +) + +__all__ = [ + "calculate_rsi_score", + "calculate_macd_score", + "calculate_sma_score", + "calculate_ema_direction", + "calculate_momentum_score", + "calculate_volume_ratio", + "calculate_volume_trend", + "calculate_dollar_volume", + "calculate_volume_score", + "calculate_volume_metrics", + "SECTOR_ETF_MAP", + "calculate_return", + "calculate_relative_strength", + "get_sector_etf", + "calculate_relative_strength_metrics", + "find_support_levels", + "find_resistance_levels", + "detect_swing_points", + "get_nearest_levels", + "calculate_support_resistance_metrics", + "calculate_stop_loss", + "calculate_reward_target", + "calculate_risk_reward_ratio", + "calculate_risk_reward_score", +] diff --git a/tradingagents/agents/discovery/indicators/momentum.py b/tradingagents/agents/discovery/indicators/momentum.py new file mode 100644 index 00000000..a57ca914 --- /dev/null +++ b/tradingagents/agents/discovery/indicators/momentum.py @@ -0,0 +1,349 @@ +import logging +from datetime import datetime + +import pandas as pd + +logger = logging.getLogger(__name__) + + +def calculate_rsi_score(rsi: float) -> float: + if rsi <= 20: + return 1.0 + elif rsi <= 30: + return 0.8 + (30 - rsi) / 50 + elif rsi <= 35: + return 0.7 + (35 - rsi) / 50 + elif rsi <= 50: + return 0.6 + (50 - rsi) / 150 + elif rsi <= 65: + return 0.5 + (65 - rsi) / 150 + elif rsi <= 70: + return 0.4 + (70 - rsi) / 50 + elif rsi <= 80: + return 0.2 + (80 - rsi) / 50 + else: + return max(0.0, 0.2 - (rsi - 80) / 100) + + +def calculate_macd_score(macd: float, signal: float, histogram: float) -> float: + score = 0.5 + + if macd > signal: + crossover_strength = min((macd - signal) / max(abs(signal), 0.01), 1.0) + score += 0.25 * crossover_strength + else: + crossover_weakness = min((signal - macd) / max(abs(signal), 0.01), 1.0) + score -= 0.25 * crossover_weakness + + if histogram > 0: + histogram_bonus = min(histogram / max(abs(macd), 0.01), 1.0) * 0.25 + score += histogram_bonus + else: + histogram_penalty = min(abs(histogram) / max(abs(macd), 0.01), 1.0) * 0.25 + score -= histogram_penalty + + return max(0.0, min(1.0, score)) + + +def calculate_sma_score(price: float, sma50: float, sma200: float) -> float: + if sma50 == 0 or sma200 == 0: + return 0.5 + + score = 0.5 + + pct_vs_sma50 = (price - sma50) / sma50 + pct_vs_sma200 = (price - sma200) / sma200 + + if pct_vs_sma50 > 0: + score += min(pct_vs_sma50 * 2, 0.25) + else: + score += max(pct_vs_sma50 * 2, -0.25) + + if pct_vs_sma200 > 0: + score += min(pct_vs_sma200 * 2, 0.25) + else: + score += max(pct_vs_sma200 * 2, -0.25) + + if price > sma50 > sma200: + score += 0.15 + + return max(0.0, min(1.0, score)) + + +def calculate_ema_direction(ema_values: list[float]) -> str: + if len(ema_values) < 2: + return "flat" + + first_half = ema_values[: len(ema_values) // 2] + second_half = ema_values[len(ema_values) // 2 :] + + if not first_half or not second_half: + return "flat" + + first_avg = sum(first_half) / len(first_half) + second_avg = sum(second_half) / len(second_half) + + pct_change = (second_avg - first_avg) / first_avg if first_avg != 0 else 0 + + if pct_change > 0.01: + return "up" + elif pct_change < -0.01: + return "down" + else: + return "flat" + + +def _get_stock_stats_bulk(symbol: str, indicator: str, curr_date: str) -> dict: + import os + + from stockstats import wrap + + from tradingagents.dataflows.config import get_config + + config = get_config() + online = config["data_vendors"]["technical_indicators"] != "local" + + if not online: + try: + data = pd.read_csv( + os.path.join( + config.get("data_cache_dir", "data"), + f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + df = wrap(data) + except FileNotFoundError: + raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") + else: + import yfinance as yf + + today_date = pd.Timestamp.today() + + end_date = today_date + start_date = today_date - pd.DateOffset(years=15) + start_date_str = start_date.strftime("%Y-%m-%d") + end_date_str = end_date.strftime("%Y-%m-%d") + + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", + ) + + if os.path.exists(data_file): + data = pd.read_csv(data_file) + data["Date"] = pd.to_datetime(data["Date"]) + else: + data = yf.download( + symbol, + start=start_date_str, + end=end_date_str, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + data = data.reset_index() + data.to_csv(data_file, index=False) + + df = wrap(data) + df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") + + df[indicator] + + indicator_series = df[indicator].apply(lambda x: "N/A" if pd.isna(x) else str(x)) + result_dict = dict(zip(df["Date"], indicator_series, strict=False)) + + return result_dict + + +def _get_price_data(symbol: str, curr_date: str) -> dict: + import os + + from tradingagents.dataflows.config import get_config + + config = get_config() + online = config["data_vendors"]["technical_indicators"] != "local" + + if not online: + try: + data = pd.read_csv( + os.path.join( + config.get("data_cache_dir", "data"), + f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + except FileNotFoundError: + return {} + else: + import yfinance as yf + + today_date = pd.Timestamp.today() + + end_date = today_date + start_date = today_date - pd.DateOffset(years=15) + start_date_str = start_date.strftime("%Y-%m-%d") + end_date_str = end_date.strftime("%Y-%m-%d") + + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", + ) + + if os.path.exists(data_file): + data = pd.read_csv(data_file) + else: + data = yf.download( + symbol, + start=start_date_str, + end=end_date_str, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + data = data.reset_index() + data.to_csv(data_file, index=False) + + data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d") + result_dict = {} + for _, row in data.iterrows(): + result_dict[row["Date"]] = row["Close"] + + return result_dict + + +def calculate_momentum_score(ticker: str, curr_date: str) -> dict: + result = { + "rsi": None, + "macd": None, + "macd_signal": None, + "macd_histogram": None, + "price_vs_sma50": None, + "price_vs_sma200": None, + "ema10_direction": None, + "momentum_score": 0.5, + } + + try: + rsi_data = _get_stock_stats_bulk(ticker, "rsi", curr_date) + macd_data = _get_stock_stats_bulk(ticker, "macd", curr_date) + macds_data = _get_stock_stats_bulk(ticker, "macds", curr_date) + macdh_data = _get_stock_stats_bulk(ticker, "macdh", curr_date) + sma50_data = _get_stock_stats_bulk(ticker, "close_50_sma", curr_date) + sma200_data = _get_stock_stats_bulk(ticker, "close_200_sma", curr_date) + ema10_data = _get_stock_stats_bulk(ticker, "close_10_ema", curr_date) + price_data = _get_price_data(ticker, curr_date) + + rsi_value = None + if curr_date in rsi_data and rsi_data[curr_date] != "N/A": + try: + rsi_value = float(rsi_data[curr_date]) + result["rsi"] = rsi_value + except (ValueError, TypeError): + pass + + macd_value = None + macds_value = None + macdh_value = None + + if curr_date in macd_data and macd_data[curr_date] != "N/A": + try: + macd_value = float(macd_data[curr_date]) + result["macd"] = macd_value + except (ValueError, TypeError): + pass + + if curr_date in macds_data and macds_data[curr_date] != "N/A": + try: + macds_value = float(macds_data[curr_date]) + result["macd_signal"] = macds_value + except (ValueError, TypeError): + pass + + if curr_date in macdh_data and macdh_data[curr_date] != "N/A": + try: + macdh_value = float(macdh_data[curr_date]) + result["macd_histogram"] = macdh_value + except (ValueError, TypeError): + pass + + current_price = None + sma50_value = None + sma200_value = None + + if curr_date in price_data: + try: + current_price = float(price_data[curr_date]) + except (ValueError, TypeError): + pass + + if curr_date in sma50_data and sma50_data[curr_date] != "N/A": + try: + sma50_value = float(sma50_data[curr_date]) + except (ValueError, TypeError): + pass + + if curr_date in sma200_data and sma200_data[curr_date] != "N/A": + try: + sma200_value = float(sma200_data[curr_date]) + except (ValueError, TypeError): + pass + + if current_price and sma50_value: + result["price_vs_sma50"] = ( + (current_price - sma50_value) / sma50_value + ) * 100 + + if current_price and sma200_value: + result["price_vs_sma200"] = ( + (current_price - sma200_value) / sma200_value + ) * 100 + + ema_values = [] + sorted_dates = sorted( + [d for d in ema10_data.keys() if d <= curr_date], reverse=True + )[:10] + for date in sorted_dates: + if ema10_data[date] != "N/A": + try: + ema_values.append(float(ema10_data[date])) + except (ValueError, TypeError): + pass + + if ema_values: + result["ema10_direction"] = calculate_ema_direction(ema_values[::-1]) + + scores = [] + weights = [] + + if rsi_value is not None: + scores.append(calculate_rsi_score(rsi_value)) + weights.append(0.25) + + if ( + macd_value is not None + and macds_value is not None + and macdh_value is not None + ): + scores.append(calculate_macd_score(macd_value, macds_value, macdh_value)) + weights.append(0.35) + + if current_price and sma50_value and sma200_value: + scores.append(calculate_sma_score(current_price, sma50_value, sma200_value)) + weights.append(0.40) + + if scores and weights: + total_weight = sum(weights) + result["momentum_score"] = ( + sum(s * w for s, w in zip(scores, weights, strict=False)) / total_weight + ) + else: + result["momentum_score"] = 0.5 + + except (KeyError, ValueError, FileNotFoundError, RuntimeError) as e: + logger.warning("Failed to calculate momentum score for %s: %s", ticker, str(e)) + result["momentum_score"] = 0.5 + + return result diff --git a/tradingagents/agents/discovery/indicators/relative_strength.py b/tradingagents/agents/discovery/indicators/relative_strength.py new file mode 100644 index 00000000..7d37d0ad --- /dev/null +++ b/tradingagents/agents/discovery/indicators/relative_strength.py @@ -0,0 +1,181 @@ +import logging +import os + +import pandas as pd + +from tradingagents.dataflows.trending.sector_classifier import classify_sector + +logger = logging.getLogger(__name__) + +SECTOR_ETF_MAP = { + "technology": "XLK", + "finance": "XLF", + "healthcare": "XLV", + "energy": "XLE", + "consumer_goods": "XLY", + "industrials": "XLI", + "other": "SPY", +} + + +def calculate_return(prices: list[float], days: int) -> float: + if len(prices) < 2: + return 0.0 + + start_idx = max(0, len(prices) - days) + start_price = prices[start_idx] + end_price = prices[-1] + + if start_price == 0: + return 0.0 + + return ((end_price - start_price) / start_price) * 100 + + +def calculate_relative_strength(stock_return: float, benchmark_return: float) -> float: + return stock_return - benchmark_return + + +def get_sector_etf(ticker: str) -> str: + sector = classify_sector(ticker) + return SECTOR_ETF_MAP.get(sector, "SPY") + + +def _get_price_history(ticker: str, curr_date: str, days: int) -> list[float]: + from tradingagents.dataflows.config import get_config + + config = get_config() + online = config["data_vendors"]["technical_indicators"] != "local" + + try: + if not online: + data = pd.read_csv( + os.path.join( + config.get("data_cache_dir", "data"), + f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + else: + import yfinance as yf + + today_date = pd.Timestamp.today() + + end_date = today_date + start_date = today_date - pd.DateOffset(years=15) + start_date_str = start_date.strftime("%Y-%m-%d") + end_date_str = end_date.strftime("%Y-%m-%d") + + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv", + ) + + if os.path.exists(data_file): + data = pd.read_csv(data_file) + else: + data = yf.download( + ticker, + start=start_date_str, + end=end_date_str, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + data = data.reset_index() + data.to_csv(data_file, index=False) + + data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d") + data = data[data["Date"] <= curr_date].tail(days + 5) + + return data["Close"].tolist() + + except (FileNotFoundError, KeyError, ValueError) as e: + logger.warning("Failed to get price history for %s: %s", ticker, str(e)) + return [] + + +def _calculate_rs_score( + rs_5d: float, rs_20d: float, rs_60d: float, rs_sector: float +) -> float: + score = 0.5 + + if rs_20d > 5: + score += min(rs_20d / 20, 0.25) + elif rs_20d < -5: + score -= min(abs(rs_20d) / 20, 0.25) + + if rs_5d > 3: + score += min(rs_5d / 15, 0.15) + elif rs_5d < -3: + score -= min(abs(rs_5d) / 15, 0.15) + + if rs_60d > 10: + score += min(rs_60d / 40, 0.1) + elif rs_60d < -10: + score -= min(abs(rs_60d) / 40, 0.1) + + if rs_sector > 3: + score += min(rs_sector / 15, 0.1) + elif rs_sector < -3: + score -= min(abs(rs_sector) / 15, 0.1) + + return max(0.0, min(1.0, score)) + + +def calculate_relative_strength_metrics(ticker: str, curr_date: str) -> dict: + result = { + "rs_vs_spy_5d": None, + "rs_vs_spy_20d": None, + "rs_vs_spy_60d": None, + "rs_vs_sector": None, + "sector_etf": None, + "relative_strength_score": 0.5, + } + + try: + stock_prices = _get_price_history(ticker, curr_date, 70) + spy_prices = _get_price_history("SPY", curr_date, 70) + + if not stock_prices or not spy_prices: + return result + + sector_etf = get_sector_etf(ticker) + result["sector_etf"] = sector_etf + + sector_prices = [] + if sector_etf != "SPY": + sector_prices = _get_price_history(sector_etf, curr_date, 70) + + stock_5d = calculate_return(stock_prices, 5) + stock_20d = calculate_return(stock_prices, 20) + stock_60d = calculate_return(stock_prices, 60) + + spy_5d = calculate_return(spy_prices, 5) + spy_20d = calculate_return(spy_prices, 20) + spy_60d = calculate_return(spy_prices, 60) + + result["rs_vs_spy_5d"] = calculate_relative_strength(stock_5d, spy_5d) + result["rs_vs_spy_20d"] = calculate_relative_strength(stock_20d, spy_20d) + result["rs_vs_spy_60d"] = calculate_relative_strength(stock_60d, spy_60d) + + if sector_prices and sector_etf != "SPY": + sector_20d = calculate_return(sector_prices, 20) + result["rs_vs_sector"] = calculate_relative_strength(stock_20d, sector_20d) + else: + result["rs_vs_sector"] = result["rs_vs_spy_20d"] + + result["relative_strength_score"] = _calculate_rs_score( + result["rs_vs_spy_5d"], + result["rs_vs_spy_20d"], + result["rs_vs_spy_60d"], + result["rs_vs_sector"], + ) + + except (KeyError, ValueError, RuntimeError) as e: + logger.warning( + "Failed to calculate relative strength for %s: %s", ticker, str(e) + ) + + return result diff --git a/tradingagents/agents/discovery/indicators/risk_reward.py b/tradingagents/agents/discovery/indicators/risk_reward.py new file mode 100644 index 00000000..2ba96443 --- /dev/null +++ b/tradingagents/agents/discovery/indicators/risk_reward.py @@ -0,0 +1,34 @@ +import logging + +logger = logging.getLogger(__name__) + + +def calculate_stop_loss(price: float, atr: float, multiplier: float = 1.5) -> float: + return price - (atr * multiplier) + + +def calculate_reward_target(price: float, resistance: float) -> float: + return resistance + + +def calculate_risk_reward_ratio(price: float, stop: float, target: float) -> float: + risk = price - stop + if risk == 0: + return 0.0 + + reward = target - price + return reward / risk + + +def calculate_risk_reward_score(rr_ratio: float) -> float: + if rr_ratio < 0: + return 0.0 + + if rr_ratio >= 3.0: + return 0.9 + min((rr_ratio - 3.0) / 10, 0.1) + elif rr_ratio >= 2.0: + return 0.7 + (rr_ratio - 2.0) / 5 + elif rr_ratio >= 1.0: + return 0.4 + (rr_ratio - 1.0) * 0.3 + else: + return rr_ratio * 0.4 diff --git a/tradingagents/agents/discovery/indicators/support_resistance.py b/tradingagents/agents/discovery/indicators/support_resistance.py new file mode 100644 index 00000000..0814d9cc --- /dev/null +++ b/tradingagents/agents/discovery/indicators/support_resistance.py @@ -0,0 +1,260 @@ +import logging +import os + +import pandas as pd + +from tradingagents.agents.discovery.indicators.risk_reward import ( + calculate_reward_target, + calculate_risk_reward_ratio, + calculate_risk_reward_score, + calculate_stop_loss, +) + +logger = logging.getLogger(__name__) + + +def find_support_levels( + lows: list[float], lookback_20: int, lookback_50: int +) -> tuple[float, float]: + if not lows: + return (0.0, 0.0) + + support_20d = ( + min(lows[-lookback_20:]) + if len(lows) >= lookback_20 + else min(lows) + if lows + else 0.0 + ) + support_50d = ( + min(lows[-lookback_50:]) + if len(lows) >= lookback_50 + else min(lows) + if lows + else 0.0 + ) + + return (support_20d, support_50d) + + +def find_resistance_levels( + highs: list[float], lookback_20: int, lookback_50: int +) -> tuple[float, float]: + if not highs: + return (0.0, 0.0) + + resistance_20d = ( + max(highs[-lookback_20:]) + if len(highs) >= lookback_20 + else max(highs) + if highs + else 0.0 + ) + resistance_50d = ( + max(highs[-lookback_50:]) + if len(highs) >= lookback_50 + else max(highs) + if highs + else 0.0 + ) + + return (resistance_20d, resistance_50d) + + +def detect_swing_points( + prices: list[float], n_bars: int = 5 +) -> tuple[list[float], list[float]]: + if len(prices) < (2 * n_bars + 1): + return ([], []) + + swing_lows = [] + swing_highs = [] + + for i in range(n_bars, len(prices) - n_bars): + is_swing_low = True + is_swing_high = True + + current = prices[i] + + for j in range(1, n_bars + 1): + if prices[i - j] <= current or prices[i + j] <= current: + is_swing_low = False + if prices[i - j] >= current or prices[i + j] >= current: + is_swing_high = False + + if is_swing_low: + swing_lows.append(current) + if is_swing_high: + swing_highs.append(current) + + return (swing_lows, swing_highs) + + +def get_nearest_levels( + price: float, supports: list[float], resistances: list[float] +) -> tuple[float, float]: + supports_below = [s for s in supports if s < price] + resistances_above = [r for r in resistances if r > price] + + nearest_support = max(supports_below) if supports_below else 0.0 + nearest_resistance = min(resistances_above) if resistances_above else 0.0 + + return (nearest_support, nearest_resistance) + + +def _get_ohlc_data(ticker: str, curr_date: str) -> dict: + from tradingagents.dataflows.config import get_config + + config = get_config() + online = config["data_vendors"]["technical_indicators"] != "local" + + result = { + "highs": [], + "lows": [], + "closes": [], + "current_price": None, + } + + try: + if not online: + data = pd.read_csv( + os.path.join( + config.get("data_cache_dir", "data"), + f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + else: + import yfinance as yf + + today_date = pd.Timestamp.today() + + end_date = today_date + start_date = today_date - pd.DateOffset(years=15) + start_date_str = start_date.strftime("%Y-%m-%d") + end_date_str = end_date.strftime("%Y-%m-%d") + + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv", + ) + + if os.path.exists(data_file): + data = pd.read_csv(data_file) + else: + data = yf.download( + ticker, + start=start_date_str, + end=end_date_str, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + data = data.reset_index() + data.to_csv(data_file, index=False) + + data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d") + data = data[data["Date"] <= curr_date].tail(60) + + if len(data) > 0: + result["highs"] = data["High"].tolist() + result["lows"] = data["Low"].tolist() + result["closes"] = data["Close"].tolist() + result["current_price"] = data["Close"].iloc[-1] if len(data) > 0 else None + + except (FileNotFoundError, KeyError, ValueError) as e: + logger.warning("Failed to get OHLC data for %s: %s", ticker, str(e)) + + return result + + +def _get_atr(ticker: str, curr_date: str) -> float | None: + from tradingagents.agents.discovery.indicators.momentum import _get_stock_stats_bulk + + try: + atr_data = _get_stock_stats_bulk(ticker, "atr", curr_date) + if curr_date in atr_data and atr_data[curr_date] != "N/A": + return float(atr_data[curr_date]) + except (KeyError, ValueError, FileNotFoundError) as e: + logger.warning("Failed to get ATR for %s: %s", ticker, str(e)) + + return None + + +def calculate_support_resistance_metrics(ticker: str, curr_date: str) -> dict: + result = { + "support_level": None, + "resistance_level": None, + "atr": None, + "suggested_stop": None, + "reward_target": None, + "risk_reward_ratio": None, + "risk_reward_score": 0.5, + } + + try: + ohlc = _get_ohlc_data(ticker, curr_date) + atr = _get_atr(ticker, curr_date) + + highs = ohlc["highs"] + lows = ohlc["lows"] + closes = ohlc["closes"] + current_price = ohlc["current_price"] + + if not highs or not lows or current_price is None: + return result + + result["atr"] = atr + + support_20d, support_50d = find_support_levels(lows, 20, 50) + resistance_20d, resistance_50d = find_resistance_levels(highs, 20, 50) + + swing_lows, swing_highs = detect_swing_points(closes, n_bars=5) + + all_supports = [s for s in [support_20d, support_50d] + swing_lows if s > 0] + all_resistances = [ + r for r in [resistance_20d, resistance_50d] + swing_highs if r > 0 + ] + + nearest_support, nearest_resistance = get_nearest_levels( + current_price, all_supports, all_resistances + ) + + result["support_level"] = ( + nearest_support if nearest_support > 0 else support_20d + ) + result["resistance_level"] = ( + nearest_resistance if nearest_resistance > 0 else resistance_20d + ) + + if atr and atr > 0: + result["suggested_stop"] = calculate_stop_loss( + current_price, atr, multiplier=1.5 + ) + else: + pct_to_support = ( + (current_price - result["support_level"]) / current_price + if result["support_level"] > 0 + else 0.02 + ) + result["suggested_stop"] = current_price * (1 - max(pct_to_support, 0.02)) + + result["reward_target"] = calculate_reward_target( + current_price, result["resistance_level"] + ) + + if result["suggested_stop"] and result["reward_target"]: + result["risk_reward_ratio"] = calculate_risk_reward_ratio( + current_price, result["suggested_stop"], result["reward_target"] + ) + result["risk_reward_score"] = calculate_risk_reward_score( + result["risk_reward_ratio"] + ) + + except (KeyError, ValueError, RuntimeError) as e: + logger.warning( + "Failed to calculate support/resistance metrics for %s: %s", ticker, str(e) + ) + + return result diff --git a/tradingagents/agents/discovery/indicators/timeframe.py b/tradingagents/agents/discovery/indicators/timeframe.py new file mode 100644 index 00000000..5e8d057b --- /dev/null +++ b/tradingagents/agents/discovery/indicators/timeframe.py @@ -0,0 +1,203 @@ +import logging + +logger = logging.getLogger(__name__) + + +def determine_signal_from_indicators( + rsi: float | None, + macd: float | None, + macd_signal: float | None, + price_vs_sma: float | None, + ema_direction: str | None, +) -> str: + bullish_signals = 0 + bearish_signals = 0 + total_signals = 0 + + if rsi is not None: + total_signals += 1 + if rsi < 40: + bullish_signals += 1 + elif rsi > 60: + bearish_signals += 1 + + if macd is not None and macd_signal is not None: + total_signals += 1 + if macd > macd_signal: + bullish_signals += 1 + else: + bearish_signals += 1 + + if price_vs_sma is not None: + total_signals += 1 + if price_vs_sma > 0: + bullish_signals += 1 + elif price_vs_sma < 0: + bearish_signals += 1 + + if ema_direction is not None: + total_signals += 1 + if ema_direction == "up": + bullish_signals += 1 + elif ema_direction == "down": + bearish_signals += 1 + + if total_signals == 0: + return "neutral" + + bullish_ratio = bullish_signals / total_signals + bearish_ratio = bearish_signals / total_signals + + if bullish_ratio >= 0.6: + return "bullish" + elif bearish_ratio >= 0.6: + return "bearish" + else: + return "neutral" + + +def calculate_timeframe_signals( + momentum_data: dict, + relative_strength_data: dict, +) -> dict: + result = { + "short_term_signal": "neutral", + "medium_term_signal": "neutral", + "long_term_signal": "neutral", + "timeframe_alignment": "neutral", + "signal_strength": 0.5, + } + + try: + rsi = momentum_data.get("rsi") + macd = momentum_data.get("macd") + macd_signal = momentum_data.get("macd_signal") + ema_direction = momentum_data.get("ema10_direction") + price_vs_sma50 = momentum_data.get("price_vs_sma50") + price_vs_sma200 = momentum_data.get("price_vs_sma200") + + rs_5d = relative_strength_data.get("rs_vs_spy_5d") + rs_20d = relative_strength_data.get("rs_vs_spy_20d") + rs_60d = relative_strength_data.get("rs_vs_spy_60d") + + short_bullish = 0 + short_bearish = 0 + short_total = 0 + + if rsi is not None: + short_total += 1 + if rsi < 35: + short_bullish += 1 + elif rsi > 65: + short_bearish += 1 + + if ema_direction == "up": + short_total += 1 + short_bullish += 1 + elif ema_direction == "down": + short_total += 1 + short_bearish += 1 + + if rs_5d is not None: + short_total += 1 + if rs_5d > 0: + short_bullish += 1 + elif rs_5d < 0: + short_bearish += 1 + + if short_total > 0: + if short_bullish / short_total >= 0.6: + result["short_term_signal"] = "bullish" + elif short_bearish / short_total >= 0.6: + result["short_term_signal"] = "bearish" + + med_bullish = 0 + med_bearish = 0 + med_total = 0 + + if macd is not None and macd_signal is not None: + med_total += 1 + if macd > macd_signal: + med_bullish += 1 + else: + med_bearish += 1 + + if price_vs_sma50 is not None: + med_total += 1 + if price_vs_sma50 > 0: + med_bullish += 1 + elif price_vs_sma50 < -2: + med_bearish += 1 + + if rs_20d is not None: + med_total += 1 + if rs_20d > 0: + med_bullish += 1 + elif rs_20d < 0: + med_bearish += 1 + + if med_total > 0: + if med_bullish / med_total >= 0.6: + result["medium_term_signal"] = "bullish" + elif med_bearish / med_total >= 0.6: + result["medium_term_signal"] = "bearish" + + long_bullish = 0 + long_bearish = 0 + long_total = 0 + + if price_vs_sma200 is not None: + long_total += 1 + if price_vs_sma200 > 0: + long_bullish += 1 + elif price_vs_sma200 < -5: + long_bearish += 1 + + if rs_60d is not None: + long_total += 1 + if rs_60d > 0: + long_bullish += 1 + elif rs_60d < 0: + long_bearish += 1 + + if price_vs_sma50 is not None and price_vs_sma200 is not None: + long_total += 1 + if price_vs_sma50 > 0 and price_vs_sma200 > 0: + long_bullish += 1 + elif price_vs_sma50 < 0 and price_vs_sma200 < 0: + long_bearish += 1 + + if long_total > 0: + if long_bullish / long_total >= 0.6: + result["long_term_signal"] = "bullish" + elif long_bearish / long_total >= 0.6: + result["long_term_signal"] = "bearish" + + signals = [ + result["short_term_signal"], + result["medium_term_signal"], + result["long_term_signal"], + ] + bullish_count = signals.count("bullish") + bearish_count = signals.count("bearish") + + if bullish_count == 3: + result["timeframe_alignment"] = "aligned_bullish" + result["signal_strength"] = 1.0 + elif bearish_count == 3: + result["timeframe_alignment"] = "aligned_bearish" + result["signal_strength"] = 0.0 + elif bullish_count >= 2: + result["timeframe_alignment"] = "mixed" + result["signal_strength"] = 0.7 + elif bearish_count >= 2: + result["timeframe_alignment"] = "mixed" + result["signal_strength"] = 0.3 + else: + result["timeframe_alignment"] = "neutral" + result["signal_strength"] = 0.5 + + except (KeyError, TypeError, ValueError) as e: + logger.warning("Failed to calculate timeframe signals: %s", str(e)) + + return result diff --git a/tradingagents/agents/discovery/indicators/volume.py b/tradingagents/agents/discovery/indicators/volume.py new file mode 100644 index 00000000..7cb92281 --- /dev/null +++ b/tradingagents/agents/discovery/indicators/volume.py @@ -0,0 +1,213 @@ +import logging +import os + +import numpy as np +import pandas as pd + +from tradingagents.config import get_settings + +logger = logging.getLogger(__name__) + + +def calculate_volume_ratio(current_volume: float, avg_volume_20d: float) -> float: + if avg_volume_20d == 0: + return 0.0 + return current_volume / avg_volume_20d + + +def calculate_volume_trend(volume_series: list[float]) -> tuple[str, float]: + if len(volume_series) < 2: + return ("flat", 0.0) + + x = np.arange(len(volume_series)) + y = np.array(volume_series) + + coeffs = np.polyfit(x, y, 1) + slope = coeffs[0] + + avg_volume = np.mean(y) + if avg_volume == 0: + return ("flat", 0.0) + + normalized_slope = slope / avg_volume + + if normalized_slope > 0.02: + return ("increasing", float(slope)) + elif normalized_slope < -0.02: + return ("decreasing", float(slope)) + else: + return ("flat", float(slope)) + + +def calculate_dollar_volume(price: float, volume: float) -> float: + return price * volume + + +def calculate_volume_score( + volume_ratio: float, + trend: str, + dollar_volume: float, + price_change: float, + min_dollar_volume: float, +) -> float: + score = 0.5 + + if volume_ratio >= 2.0: + if price_change > 0: + score += 0.35 + else: + score += 0.15 + elif volume_ratio >= 1.5: + if price_change > 0: + score += 0.25 + else: + score += 0.10 + elif volume_ratio >= 1.0: + score += 0.05 + else: + score -= 0.1 + + if trend == "increasing": + score += 0.1 + elif trend == "decreasing": + score -= 0.05 + + if dollar_volume < min_dollar_volume: + liquidity_penalty = min( + 0.3, (min_dollar_volume - dollar_volume) / min_dollar_volume * 0.3 + ) + score -= liquidity_penalty + + return max(0.0, min(1.0, score)) + + +def _get_volume_price_data(ticker: str, curr_date: str) -> dict: + from tradingagents.dataflows.config import get_config + + config = get_config() + online = config["data_vendors"]["technical_indicators"] != "local" + + result = { + "volumes": [], + "prices": [], + "current_price": None, + "current_volume": None, + "price_change_pct": 0.0, + } + + try: + if not online: + data = pd.read_csv( + os.path.join( + config.get("data_cache_dir", "data"), + f"{ticker}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + else: + import yfinance as yf + + today_date = pd.Timestamp.today() + + end_date = today_date + start_date = today_date - pd.DateOffset(years=15) + start_date_str = start_date.strftime("%Y-%m-%d") + end_date_str = end_date.strftime("%Y-%m-%d") + + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{ticker}-YFin-data-{start_date_str}-{end_date_str}.csv", + ) + + if os.path.exists(data_file): + data = pd.read_csv(data_file) + else: + data = yf.download( + ticker, + start=start_date_str, + end=end_date_str, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + data = data.reset_index() + data.to_csv(data_file, index=False) + + data["Date"] = pd.to_datetime(data["Date"]).dt.strftime("%Y-%m-%d") + data = data[data["Date"] <= curr_date].tail(30) + + if len(data) > 0: + volumes = data["Volume"].tolist() + prices = data["Close"].tolist() + + result["volumes"] = volumes[:-1] if len(volumes) > 1 else volumes + result["prices"] = prices[:-1] if len(prices) > 1 else prices + result["current_volume"] = volumes[-1] if volumes else None + result["current_price"] = prices[-1] if prices else None + + if len(prices) >= 2: + result["price_change_pct"] = ( + (prices[-1] - prices[-2]) / prices[-2] * 100 + if prices[-2] != 0 + else 0 + ) + + except (FileNotFoundError, KeyError, ValueError) as e: + logger.warning("Failed to get volume/price data for %s: %s", ticker, str(e)) + + return result + + +def calculate_volume_metrics(ticker: str, curr_date: str) -> dict: + result = { + "volume_ratio": None, + "volume_trend": None, + "volume_trend_slope": None, + "dollar_volume": None, + "volume_score": 0.5, + } + + try: + settings = get_settings() + min_dollar_volume = settings.min_dollar_volume + + data = _get_volume_price_data(ticker, curr_date) + + volumes = data["volumes"] + current_volume = data["current_volume"] + current_price = data["current_price"] + price_change = data["price_change_pct"] + + if not volumes or current_volume is None or current_price is None: + return result + + avg_volume_20d = ( + sum(volumes[-20:]) / min(len(volumes[-20:]), 20) if volumes else 0 + ) + + volume_ratio = calculate_volume_ratio(current_volume, avg_volume_20d) + result["volume_ratio"] = volume_ratio + + trend_volumes = ( + volumes[-10:] + [current_volume] if volumes else [current_volume] + ) + trend, slope = calculate_volume_trend(trend_volumes) + result["volume_trend"] = trend + result["volume_trend_slope"] = slope + + dollar_volume = calculate_dollar_volume(current_price, current_volume) + result["dollar_volume"] = dollar_volume + + result["volume_score"] = calculate_volume_score( + volume_ratio=volume_ratio, + trend=trend, + dollar_volume=dollar_volume, + price_change=price_change, + min_dollar_volume=min_dollar_volume, + ) + + except (KeyError, ValueError, RuntimeError) as e: + logger.warning("Failed to calculate volume metrics for %s: %s", ticker, str(e)) + + return result diff --git a/tradingagents/agents/discovery/models.py b/tradingagents/agents/discovery/models.py index 41be8e91..0e41f722 100644 --- a/tradingagents/agents/discovery/models.py +++ b/tradingagents/agents/discovery/models.py @@ -1,7 +1,12 @@ +from __future__ import annotations + from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics class DiscoveryStatus(Enum): @@ -50,7 +55,7 @@ class NewsArticle: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "NewsArticle": + def from_dict(cls, data: dict[str, Any]) -> NewsArticle: return cls( title=data["title"], source=data["source"], @@ -72,9 +77,11 @@ class TrendingStock: event_type: EventCategory news_summary: str source_articles: list[NewsArticle] + quantitative_metrics: QuantitativeMetrics | None = None + conviction_score: float | None = None def to_dict(self) -> dict[str, Any]: - return { + result = { "ticker": self.ticker, "company_name": self.company_name, "score": self.score, @@ -85,9 +92,24 @@ class TrendingStock: "news_summary": self.news_summary, "source_articles": [article.to_dict() for article in self.source_articles], } + if self.quantitative_metrics is not None: + result["quantitative_metrics"] = self.quantitative_metrics.to_dict() + if self.conviction_score is not None: + result["conviction_score"] = self.conviction_score + return result @classmethod - def from_dict(cls, data: dict[str, Any]) -> "TrendingStock": + def from_dict(cls, data: dict[str, Any]) -> TrendingStock: + from tradingagents.agents.discovery.quantitative_models import ( + QuantitativeMetrics, + ) + + quantitative_metrics = None + if data.get("quantitative_metrics"): + quantitative_metrics = QuantitativeMetrics.from_dict( + data["quantitative_metrics"] + ) + return cls( ticker=data["ticker"], company_name=data["company_name"], @@ -100,6 +122,8 @@ class TrendingStock: source_articles=[ NewsArticle.from_dict(article) for article in data["source_articles"] ], + quantitative_metrics=quantitative_metrics, + conviction_score=data.get("conviction_score"), ) @@ -125,7 +149,7 @@ class DiscoveryRequest: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "DiscoveryRequest": + def from_dict(cls, data: dict[str, Any]) -> DiscoveryRequest: return cls( lookback_period=data["lookback_period"], sector_filter=( @@ -165,7 +189,7 @@ class DiscoveryResult: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "DiscoveryResult": + def from_dict(cls, data: dict[str, Any]) -> DiscoveryResult: return cls( request=DiscoveryRequest.from_dict(data["request"]), trending_stocks=[ diff --git a/tradingagents/agents/discovery/quantitative_cache.py b/tradingagents/agents/discovery/quantitative_cache.py new file mode 100644 index 00000000..f46fe0ae --- /dev/null +++ b/tradingagents/agents/discovery/quantitative_cache.py @@ -0,0 +1,47 @@ +import logging +from collections import OrderedDict +from datetime import datetime +from typing import Any + +import pandas as pd + +logger = logging.getLogger(__name__) + +MAX_CACHE_SIZE = 100 + +_price_data_cache: OrderedDict[str, tuple[pd.DataFrame, datetime]] = OrderedDict() + + +def get_cached_price_data(ticker: str) -> pd.DataFrame | None: + if ticker not in _price_data_cache: + return None + + df, timestamp = _price_data_cache[ticker] + + _price_data_cache.move_to_end(ticker) + + return df.copy() + + +def set_cached_price_data(ticker: str, data: pd.DataFrame) -> None: + global _price_data_cache + + if ticker in _price_data_cache: + _price_data_cache.move_to_end(ticker) + _price_data_cache[ticker] = (data.copy(), datetime.now()) + return + + while len(_price_data_cache) >= MAX_CACHE_SIZE: + _price_data_cache.popitem(last=False) + + _price_data_cache[ticker] = (data.copy(), datetime.now()) + logger.debug( + "Cached price data for %s (cache size: %d)", ticker, len(_price_data_cache) + ) + + +def clear_run_cache() -> None: + global _price_data_cache + count = len(_price_data_cache) + _price_data_cache.clear() + logger.debug("Cleared run cache (%d entries removed)", count) diff --git a/tradingagents/agents/discovery/quantitative_models.py b/tradingagents/agents/discovery/quantitative_models.py new file mode 100644 index 00000000..f872c0d7 --- /dev/null +++ b/tradingagents/agents/discovery/quantitative_models.py @@ -0,0 +1,91 @@ +from typing import Any + +from pydantic import BaseModel, Field, field_validator + + +class QuantitativeMetrics(BaseModel): + momentum_score: float = Field(ge=0.0, le=1.0) + volume_score: float = Field(ge=0.0, le=1.0) + relative_strength_score: float = Field(ge=0.0, le=1.0) + risk_reward_score: float = Field(ge=0.0, le=1.0) + + rsi: float | None = None + macd: float | None = None + macd_signal: float | None = None + macd_histogram: float | None = None + + price_vs_sma50: float | None = None + price_vs_sma200: float | None = None + ema10_direction: str | None = None + + volume_ratio: float | None = None + volume_trend: str | None = None + dollar_volume: float | None = None + + rs_vs_spy_5d: float | None = None + rs_vs_spy_20d: float | None = None + rs_vs_spy_60d: float | None = None + rs_vs_sector: float | None = None + sector_etf: str | None = None + + support_level: float | None = None + resistance_level: float | None = None + atr: float | None = None + suggested_stop: float | None = None + reward_target: float | None = None + risk_reward_ratio: float | None = None + + timeframe_alignment: str | None = None + short_term_signal: str | None = None + medium_term_signal: str | None = None + long_term_signal: str | None = None + signal_strength: float | None = None + + quantitative_score: float = Field(ge=0.0, le=1.0) + + @field_validator("ema10_direction") + @classmethod + def validate_ema10_direction(cls, v: str | None) -> str | None: + if v is None: + return v + valid_directions = {"up", "down", "flat"} + if v not in valid_directions: + raise ValueError(f"ema10_direction must be one of {valid_directions}") + return v + + @field_validator("volume_trend") + @classmethod + def validate_volume_trend(cls, v: str | None) -> str | None: + if v is None: + return v + valid_trends = {"increasing", "decreasing", "flat"} + if v not in valid_trends: + raise ValueError(f"volume_trend must be one of {valid_trends}") + return v + + @field_validator("timeframe_alignment") + @classmethod + def validate_timeframe_alignment(cls, v: str | None) -> str | None: + if v is None: + return v + valid_alignments = {"aligned_bullish", "aligned_bearish", "mixed", "neutral"} + if v not in valid_alignments: + raise ValueError(f"timeframe_alignment must be one of {valid_alignments}") + return v + + @field_validator("short_term_signal", "medium_term_signal", "long_term_signal") + @classmethod + def validate_signal(cls, v: str | None) -> str | None: + if v is None: + return v + valid_signals = {"bullish", "bearish", "neutral"} + if v not in valid_signals: + raise ValueError(f"signal must be one of {valid_signals}") + return v + + def to_dict(self) -> dict[str, Any]: + return self.model_dump() + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "QuantitativeMetrics": + return cls(**data) diff --git a/tradingagents/agents/discovery/quantitative_scorer.py b/tradingagents/agents/discovery/quantitative_scorer.py new file mode 100644 index 00000000..a97490c3 --- /dev/null +++ b/tradingagents/agents/discovery/quantitative_scorer.py @@ -0,0 +1,178 @@ +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING + +from tradingagents.agents.discovery.indicators import ( + calculate_momentum_score, + calculate_relative_strength_metrics, + calculate_support_resistance_metrics, + calculate_volume_metrics, +) +from tradingagents.agents.discovery.indicators.timeframe import ( + calculate_timeframe_signals, +) +from tradingagents.agents.discovery.quantitative_cache import clear_run_cache +from tradingagents.agents.discovery.quantitative_models import QuantitativeMetrics +from tradingagents.config import QuantitativeWeightsConfig, get_settings + +if TYPE_CHECKING: + from tradingagents.agents.discovery.models import TrendingStock + +logger = logging.getLogger(__name__) + + +def calculate_unified_score( + momentum: float, + volume: float, + rs: float, + rr: float, + weights: QuantitativeWeightsConfig, +) -> float: + score = ( + momentum * weights.momentum_weight + + volume * weights.volume_weight + + rs * weights.relative_strength_weight + + rr * weights.risk_reward_weight + ) + return max(0.0, min(1.0, score)) + + +def calculate_single_stock_metrics( + ticker: str, + curr_date: str, +) -> QuantitativeMetrics | None: + try: + settings = get_settings() + weights = settings.quantitative_weights + + momentum = calculate_momentum_score(ticker, curr_date) + volume = calculate_volume_metrics(ticker, curr_date) + rs = calculate_relative_strength_metrics(ticker, curr_date) + sr = calculate_support_resistance_metrics(ticker, curr_date) + + timeframe = calculate_timeframe_signals(momentum, rs) + + momentum_score = momentum.get("momentum_score", 0.5) + volume_score = volume.get("volume_score", 0.5) + rs_score = rs.get("relative_strength_score", 0.5) + rr_score = sr.get("risk_reward_score", 0.5) + + unified_score = calculate_unified_score( + momentum=momentum_score, + volume=volume_score, + rs=rs_score, + rr=rr_score, + weights=weights, + ) + + return QuantitativeMetrics( + momentum_score=momentum_score, + volume_score=volume_score, + relative_strength_score=rs_score, + risk_reward_score=rr_score, + rsi=momentum.get("rsi"), + macd=momentum.get("macd"), + macd_signal=momentum.get("macd_signal"), + macd_histogram=momentum.get("macd_histogram"), + price_vs_sma50=momentum.get("price_vs_sma50"), + price_vs_sma200=momentum.get("price_vs_sma200"), + ema10_direction=momentum.get("ema10_direction"), + volume_ratio=volume.get("volume_ratio"), + volume_trend=volume.get("volume_trend"), + dollar_volume=volume.get("dollar_volume"), + rs_vs_spy_5d=rs.get("rs_vs_spy_5d"), + rs_vs_spy_20d=rs.get("rs_vs_spy_20d"), + rs_vs_spy_60d=rs.get("rs_vs_spy_60d"), + rs_vs_sector=rs.get("rs_vs_sector"), + sector_etf=rs.get("sector_etf"), + support_level=sr.get("support_level"), + resistance_level=sr.get("resistance_level"), + atr=sr.get("atr"), + suggested_stop=sr.get("suggested_stop"), + reward_target=sr.get("reward_target"), + risk_reward_ratio=sr.get("risk_reward_ratio"), + timeframe_alignment=timeframe.get("timeframe_alignment"), + short_term_signal=timeframe.get("short_term_signal"), + medium_term_signal=timeframe.get("medium_term_signal"), + long_term_signal=timeframe.get("long_term_signal"), + signal_strength=timeframe.get("signal_strength"), + quantitative_score=unified_score, + ) + + except Exception as e: + logger.warning( + "Failed to calculate quantitative metrics for %s: %s", ticker, str(e) + ) + return None + + +def _normalize_score(score: float, max_score: float) -> float: + if max_score == 0: + return 0.0 + return min(1.0, score / max_score) + + +def enhance_with_quantitative_scores( + stocks: list["TrendingStock"], + curr_date: str, + max_stocks: int = 50, +) -> list["TrendingStock"]: + if not stocks: + return [] + + settings = get_settings() + weights = settings.quantitative_weights + + sorted_stocks = sorted(stocks, key=lambda s: s.score, reverse=True) + stocks_to_process = sorted_stocks[:max_stocks] + remaining_stocks = sorted_stocks[max_stocks:] + + clear_run_cache() + + results: dict[str, QuantitativeMetrics | None] = {} + + with ThreadPoolExecutor(max_workers=10) as executor: + future_to_ticker = { + executor.submit( + calculate_single_stock_metrics, stock.ticker, curr_date + ): stock.ticker + for stock in stocks_to_process + } + + for future in as_completed(future_to_ticker): + ticker = future_to_ticker[future] + try: + results[ticker] = future.result() + except Exception as e: + logger.warning("Quantitative scoring failed for %s: %s", ticker, str(e)) + results[ticker] = None + + max_news_score = max((s.score for s in stocks_to_process), default=1.0) or 1.0 + + for stock in stocks_to_process: + metrics = results.get(stock.ticker) + stock.quantitative_metrics = metrics + + news_normalized = _normalize_score(stock.score, max_news_score) + + if metrics is not None: + stock.conviction_score = ( + weights.news_sentiment_weight * news_normalized + + weights.quantitative_weight * metrics.quantitative_score + ) + else: + stock.conviction_score = news_normalized * weights.news_sentiment_weight + + for stock in remaining_stocks: + stock.quantitative_metrics = None + stock.conviction_score = None + + enhanced_stocks = sorted( + stocks_to_process, + key=lambda s: s.conviction_score if s.conviction_score is not None else 0.0, + reverse=True, + ) + + clear_run_cache() + + return enhanced_stocks + remaining_stocks diff --git a/tradingagents/config.py b/tradingagents/config.py index 16d9770c..981e87f1 100644 --- a/tradingagents/config.py +++ b/tradingagents/config.py @@ -1,7 +1,7 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from pydantic_settings import BaseSettings @@ -12,6 +12,39 @@ class DataVendorsConfig(BaseModel): news_data: str = "alpha_vantage" +class QuantitativeWeightsConfig(BaseModel): + news_sentiment_weight: float = Field(default=0.50, ge=0.0, le=1.0) + quantitative_weight: float = Field(default=0.50, ge=0.0, le=1.0) + + momentum_weight: float = Field(default=0.30, ge=0.0, le=1.0) + volume_weight: float = Field(default=0.25, ge=0.0, le=1.0) + relative_strength_weight: float = Field(default=0.25, ge=0.0, le=1.0) + risk_reward_weight: float = Field(default=0.20, ge=0.0, le=1.0) + + @model_validator(mode="after") + def validate_weights_sum(self) -> "QuantitativeWeightsConfig": + top_level_sum = self.news_sentiment_weight + self.quantitative_weight + if abs(top_level_sum - 1.0) > 0.01: + raise ValueError( + f"Top-level weights (news_sentiment_weight + quantitative_weight) " + f"must sum to 1.0, got {top_level_sum}" + ) + + sub_weights_sum = ( + self.momentum_weight + + self.volume_weight + + self.relative_strength_weight + + self.risk_reward_weight + ) + if abs(sub_weights_sum - 1.0) > 0.01: + raise ValueError( + f"Sub-weights (momentum + volume + relative_strength + risk_reward) " + f"must sum to 1.0, got {sub_weights_sum}" + ) + + return self + + class TradingAgentsSettings(BaseSettings): project_dir: str = Field( default_factory=lambda: os.path.abspath( @@ -58,6 +91,14 @@ class TradingAgentsSettings(BaseSettings): data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig) tool_vendors: dict[str, Any] = Field(default_factory=dict) + quantitative_weights: QuantitativeWeightsConfig = Field( + default_factory=QuantitativeWeightsConfig + ) + quantitative_max_stocks: int = Field(default=50, ge=10, le=100) + quantitative_cache_ttl_intraday: int = Field(default=1, ge=1) + quantitative_cache_ttl_relative_strength: int = Field(default=4, ge=1) + min_dollar_volume: float = Field(default=1_000_000.0, ge=0.0) + model_config = { "env_prefix": "TRADINGAGENTS_", "env_nested_delimiter": "__", @@ -104,6 +145,7 @@ class TradingAgentsSettings(BaseSettings): def to_dict(self) -> dict[str, Any]: result = self.model_dump() result["data_vendors"] = self.data_vendors.model_dump() + result["quantitative_weights"] = self.quantitative_weights.model_dump() return result def get_api_key(self, vendor: str) -> str | None: diff --git a/tradingagents/database/services/market_data.py b/tradingagents/database/services/market_data.py index 7a837604..c40bf6d2 100644 --- a/tradingagents/database/services/market_data.py +++ b/tradingagents/database/services/market_data.py @@ -132,6 +132,11 @@ DEFAULT_TTL_HOURS = { "get_insider_sentiment": 24, "get_insider_transactions": 24, "get_bulk_news": 1, + "quant_indicators": 1, + "volume_analysis": 1, + "relative_strength": 4, + "support_resistance": 1, + "risk_reward": 1, } diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 38074098..93366c84 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -22,6 +22,9 @@ from tradingagents.agents.discovery import ( calculate_trending_scores, extract_entities, ) +from tradingagents.agents.discovery.quantitative_scorer import ( + enhance_with_quantitative_scores, +) from tradingagents.agents.utils.agent_utils import ( get_balance_sheet, get_cashflow, @@ -303,12 +306,17 @@ class TradingAgentsGraph: ) hard_timeout = self.config.get("discovery_hard_timeout", 120) + enable_quantitative = self.config.get("enable_quantitative_filtering", True) + quantitative_max_stocks = self.config.get("quantitative_max_stocks", 50) discovery_result = {"stocks": [], "error": None} def run_discovery(): try: articles = get_bulk_news(request.lookback_period) + if not articles: + discovery_result["error"] = "No articles returned from news sources" + return mentions = extract_entities(articles, self.config) @@ -326,6 +334,19 @@ class TradingAgentsGraph: min_mentions=min_mentions, ) + if enable_quantitative and trending_stocks: + curr_date = date.today().strftime("%Y-%m-%d") + logger.info( + "Enhancing %d stocks with quantitative scores (max: %d)", + len(trending_stocks), + quantitative_max_stocks, + ) + trending_stocks = enhance_with_quantitative_scores( + trending_stocks, + curr_date, + max_stocks=quantitative_max_stocks, + ) + discovery_result["stocks"] = trending_stocks except ( ValueError, diff --git a/uv.lock b/uv.lock index 81beae40..aa132adc 100644 --- a/uv.lock +++ b/uv.lock @@ -5149,6 +5149,7 @@ dependencies = [ { name = "requests" }, { name = "rich" }, { name = "setuptools" }, + { name = "sqlalchemy" }, { name = "stockstats" }, { name = "tqdm" }, { name = "tushare" }, @@ -5196,6 +5197,7 @@ requires-dist = [ { name = "rich", specifier = ">=14.0.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.2" }, { name = "setuptools", specifier = ">=80.9.0" }, + { name = "sqlalchemy", specifier = ">=2.0.0" }, { name = "stockstats", specifier = ">=0.6.5" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "tushare", specifier = ">=1.4.21" },