TradingAgents/tests/discovery/test_integration.py

490 lines
18 KiB
Python

import pytest
import math
from datetime import datetime, timedelta
from unittest.mock import patch, MagicMock
from tradingagents.agents.discovery import (
TrendingStock,
NewsArticle,
DiscoveryRequest,
DiscoveryResult,
DiscoveryStatus,
Sector,
EventCategory,
DiscoveryTimeoutError,
NewsUnavailableError,
)
from tradingagents.agents.discovery.entity_extractor import EntityMention
class TestEndToEndDiscoveryFlow:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_full_discovery_flow_from_news_to_results(
self, mock_scores, mock_extract, mock_bulk_news
):
now = datetime.now()
mock_articles = [
NewsArticle(
title="Apple announces record earnings",
source="Reuters",
url="https://reuters.com/apple-earnings",
published_at=now - timedelta(hours=2),
content_snippet="Apple Inc reported record quarterly earnings...",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="Apple stock surges on AI news",
source="Bloomberg",
url="https://bloomberg.com/apple-ai",
published_at=now - timedelta(hours=1),
content_snippet="Shares of Apple jumped after AI announcement...",
ticker_mentions=["AAPL"],
),
]
mock_bulk_news.return_value = mock_articles
mock_mentions = [
EntityMention(
company_name="Apple Inc",
confidence=0.95,
context_snippet="Apple Inc reported record quarterly earnings",
article_id="article_0",
event_type=EventCategory.EARNINGS,
sentiment=0.8,
),
EntityMention(
company_name="Apple",
confidence=0.92,
context_snippet="Shares of Apple jumped",
article_id="article_1",
event_type=EventCategory.PRODUCT_LAUNCH,
sentiment=0.7,
),
]
mock_extract.return_value = mock_mentions
mock_trending = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=8.5,
mention_count=2,
sentiment=0.75,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple reported record earnings and AI progress.",
source_articles=mock_articles,
),
]
mock_scores.return_value = mock_trending
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(lookback_period="24h")
result = graph.discover_trending(request)
assert isinstance(result, DiscoveryResult)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 1
assert result.trending_stocks[0].ticker == "AAPL"
assert result.trending_stocks[0].mention_count >= 2
mock_bulk_news.assert_called_once_with("24h")
mock_extract.assert_called_once()
mock_scores.assert_called_once()
class TestEntityExtractionToScoringPipeline:
def test_pipeline_from_extraction_to_scoring(self):
from tradingagents.agents.discovery.scorer import calculate_trending_scores
now = datetime.now()
articles = [
NewsArticle(
title="Microsoft cloud revenue grows",
source="WSJ",
url="https://wsj.com/article1",
published_at=now - timedelta(hours=2),
content_snippet="Microsoft Corporation reported strong cloud growth.",
ticker_mentions=["MSFT"],
),
NewsArticle(
title="Microsoft earnings beat estimates",
source="CNBC",
url="https://cnbc.com/article2",
published_at=now - timedelta(hours=3),
content_snippet="Microsoft earnings exceeded analyst expectations.",
ticker_mentions=["MSFT"],
),
NewsArticle(
title="Tech stocks rally",
source="Bloomberg",
url="https://bloomberg.com/article3",
published_at=now - timedelta(hours=1),
content_snippet="Technology companies led market gains.",
ticker_mentions=[],
),
]
mentions = [
EntityMention(
company_name="Microsoft Corporation",
confidence=0.95,
context_snippet="Microsoft Corporation reported strong cloud growth",
article_id="article_0",
event_type=EventCategory.EARNINGS,
sentiment=0.7,
),
EntityMention(
company_name="Microsoft",
confidence=0.92,
context_snippet="Microsoft earnings exceeded analyst expectations",
article_id="article_1",
event_type=EventCategory.EARNINGS,
sentiment=0.8,
),
]
with patch("tradingagents.agents.discovery.scorer.resolve_ticker") as mock_resolve:
mock_resolve.return_value = "MSFT"
with patch("tradingagents.agents.discovery.scorer.classify_sector") as mock_sector:
mock_sector.return_value = "technology"
result = calculate_trending_scores(mentions, articles, min_mentions=2)
assert len(result) == 1
assert result[0].ticker == "MSFT"
assert result[0].mention_count == 2
assert result[0].sentiment > 0
class TestNewsVendorFailureGracefulDegradation:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
def test_news_vendor_failure_with_graceful_degradation(self, mock_bulk_news):
mock_bulk_news.side_effect = NewsUnavailableError("All news vendors failed")
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
result = graph.discover_trending()
assert result.status == DiscoveryStatus.FAILED
assert result.error_message is not None
assert "news" in result.error_message.lower() or "vendor" in result.error_message.lower()
class TestTimeoutHandlingWithPartialResults:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
def test_timeout_handling_returns_error(self, mock_bulk_news):
def slow_fetch(*args, **kwargs):
import time
time.sleep(0.3)
return []
mock_bulk_news.side_effect = slow_fetch
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 0.1,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
with pytest.raises(DiscoveryTimeoutError):
graph.discover_trending()
class TestNoTrendingStocksFound:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_no_trending_stocks_found_returns_empty_list(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = [
NewsArticle(
title="General market update",
source="Reuters",
url="https://reuters.com/general",
published_at=datetime.now(),
content_snippet="Markets were quiet today with no major news.",
ticker_mentions=[],
),
]
mock_extract.return_value = []
mock_scores.return_value = []
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
result = graph.discover_trending()
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 0
class TestAllStocksFilteredOutBySectorFilter:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_all_stocks_filtered_out_by_sector_filter(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=10.0,
mention_count=5,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple earnings",
source_articles=[],
),
TrendingStock(
ticker="MSFT",
company_name="Microsoft",
score=9.0,
mention_count=4,
sentiment=0.4,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.PRODUCT_LAUNCH,
news_summary="Microsoft product",
source_articles=[],
),
]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(
lookback_period="24h",
sector_filter=[Sector.HEALTHCARE],
)
result = graph.discover_trending(request)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 0
class TestAllStocksFilteredOutByEventFilter:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_all_stocks_filtered_out_by_event_filter(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=10.0,
mention_count=5,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple earnings",
source_articles=[],
),
]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(
lookback_period="24h",
event_filter=[EventCategory.MERGER_ACQUISITION],
)
result = graph.discover_trending(request)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 0
class TestMultipleSectorsAndEventsFiltering:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_combined_sector_and_event_filtering(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=10.0,
mention_count=5,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple earnings",
source_articles=[],
),
TrendingStock(
ticker="JPM",
company_name="JPMorgan Chase",
score=9.0,
mention_count=4,
sentiment=0.4,
sector=Sector.FINANCE,
event_type=EventCategory.EARNINGS,
news_summary="JPM earnings",
source_articles=[],
),
TrendingStock(
ticker="XOM",
company_name="Exxon Mobil",
score=8.0,
mention_count=3,
sentiment=0.3,
sector=Sector.ENERGY,
event_type=EventCategory.REGULATORY,
news_summary="XOM regulatory news",
source_articles=[],
),
]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(
lookback_period="24h",
sector_filter=[Sector.TECHNOLOGY, Sector.FINANCE],
event_filter=[EventCategory.EARNINGS],
)
result = graph.discover_trending(request)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 2
tickers = [s.ticker for s in result.trending_stocks]
assert "AAPL" in tickers
assert "JPM" in tickers
assert "XOM" not in tickers
class TestDiscoveryResultPersistenceIntegration:
def test_discovery_result_can_be_serialized_and_saved(self):
from tradingagents.agents.discovery.persistence import (
save_discovery_result,
generate_markdown_summary,
)
import tempfile
import shutil
from pathlib import Path
article = NewsArticle(
title="Test article",
source="Test",
url="https://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["TEST"],
)
stock = TrendingStock(
ticker="TEST",
company_name="Test Company",
score=5.0,
mention_count=2,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.OTHER,
news_summary="Test news summary",
source_articles=[article],
)
request = DiscoveryRequest(
lookback_period="24h",
created_at=datetime.now(),
)
result = DiscoveryResult(
request=request,
trending_stocks=[stock],
status=DiscoveryStatus.COMPLETED,
started_at=datetime.now(),
completed_at=datetime.now(),
)
temp_dir = tempfile.mkdtemp()
try:
path = save_discovery_result(result, base_path=Path(temp_dir))
assert path.exists()
assert (path / "discovery_result.json").exists()
assert (path / "discovery_summary.md").exists()
markdown = generate_markdown_summary(result)
assert "TEST" in markdown
assert "Test Company" in markdown
finally:
shutil.rmtree(temp_dir)