490 lines
18 KiB
Python
490 lines
18 KiB
Python
import pytest
|
|
import math
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import patch, MagicMock
|
|
from tradingagents.agents.discovery import (
|
|
TrendingStock,
|
|
NewsArticle,
|
|
DiscoveryRequest,
|
|
DiscoveryResult,
|
|
DiscoveryStatus,
|
|
Sector,
|
|
EventCategory,
|
|
DiscoveryTimeoutError,
|
|
NewsUnavailableError,
|
|
)
|
|
from tradingagents.agents.discovery.entity_extractor import EntityMention
|
|
|
|
|
|
class TestEndToEndDiscoveryFlow:
|
|
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
|
@patch("tradingagents.graph.trading_graph.extract_entities")
|
|
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
|
def test_full_discovery_flow_from_news_to_results(
|
|
self, mock_scores, mock_extract, mock_bulk_news
|
|
):
|
|
now = datetime.now()
|
|
mock_articles = [
|
|
NewsArticle(
|
|
title="Apple announces record earnings",
|
|
source="Reuters",
|
|
url="https://reuters.com/apple-earnings",
|
|
published_at=now - timedelta(hours=2),
|
|
content_snippet="Apple Inc reported record quarterly earnings...",
|
|
ticker_mentions=["AAPL"],
|
|
),
|
|
NewsArticle(
|
|
title="Apple stock surges on AI news",
|
|
source="Bloomberg",
|
|
url="https://bloomberg.com/apple-ai",
|
|
published_at=now - timedelta(hours=1),
|
|
content_snippet="Shares of Apple jumped after AI announcement...",
|
|
ticker_mentions=["AAPL"],
|
|
),
|
|
]
|
|
mock_bulk_news.return_value = mock_articles
|
|
|
|
mock_mentions = [
|
|
EntityMention(
|
|
company_name="Apple Inc",
|
|
confidence=0.95,
|
|
context_snippet="Apple Inc reported record quarterly earnings",
|
|
article_id="article_0",
|
|
event_type=EventCategory.EARNINGS,
|
|
sentiment=0.8,
|
|
),
|
|
EntityMention(
|
|
company_name="Apple",
|
|
confidence=0.92,
|
|
context_snippet="Shares of Apple jumped",
|
|
article_id="article_1",
|
|
event_type=EventCategory.PRODUCT_LAUNCH,
|
|
sentiment=0.7,
|
|
),
|
|
]
|
|
mock_extract.return_value = mock_mentions
|
|
|
|
mock_trending = [
|
|
TrendingStock(
|
|
ticker="AAPL",
|
|
company_name="Apple Inc.",
|
|
score=8.5,
|
|
mention_count=2,
|
|
sentiment=0.75,
|
|
sector=Sector.TECHNOLOGY,
|
|
event_type=EventCategory.EARNINGS,
|
|
news_summary="Apple reported record earnings and AI progress.",
|
|
source_articles=mock_articles,
|
|
),
|
|
]
|
|
mock_scores.return_value = mock_trending
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
|
graph = TradingAgentsGraph()
|
|
graph.config = {
|
|
"discovery_timeout": 60,
|
|
"discovery_hard_timeout": 120,
|
|
"discovery_cache_ttl": 300,
|
|
"discovery_max_results": 20,
|
|
"discovery_min_mentions": 2,
|
|
}
|
|
|
|
request = DiscoveryRequest(lookback_period="24h")
|
|
result = graph.discover_trending(request)
|
|
|
|
assert isinstance(result, DiscoveryResult)
|
|
assert result.status == DiscoveryStatus.COMPLETED
|
|
assert len(result.trending_stocks) == 1
|
|
assert result.trending_stocks[0].ticker == "AAPL"
|
|
assert result.trending_stocks[0].mention_count >= 2
|
|
|
|
mock_bulk_news.assert_called_once_with("24h")
|
|
mock_extract.assert_called_once()
|
|
mock_scores.assert_called_once()
|
|
|
|
|
|
class TestEntityExtractionToScoringPipeline:
|
|
def test_pipeline_from_extraction_to_scoring(self):
|
|
from tradingagents.agents.discovery.scorer import calculate_trending_scores
|
|
|
|
now = datetime.now()
|
|
articles = [
|
|
NewsArticle(
|
|
title="Microsoft cloud revenue grows",
|
|
source="WSJ",
|
|
url="https://wsj.com/article1",
|
|
published_at=now - timedelta(hours=2),
|
|
content_snippet="Microsoft Corporation reported strong cloud growth.",
|
|
ticker_mentions=["MSFT"],
|
|
),
|
|
NewsArticle(
|
|
title="Microsoft earnings beat estimates",
|
|
source="CNBC",
|
|
url="https://cnbc.com/article2",
|
|
published_at=now - timedelta(hours=3),
|
|
content_snippet="Microsoft earnings exceeded analyst expectations.",
|
|
ticker_mentions=["MSFT"],
|
|
),
|
|
NewsArticle(
|
|
title="Tech stocks rally",
|
|
source="Bloomberg",
|
|
url="https://bloomberg.com/article3",
|
|
published_at=now - timedelta(hours=1),
|
|
content_snippet="Technology companies led market gains.",
|
|
ticker_mentions=[],
|
|
),
|
|
]
|
|
|
|
mentions = [
|
|
EntityMention(
|
|
company_name="Microsoft Corporation",
|
|
confidence=0.95,
|
|
context_snippet="Microsoft Corporation reported strong cloud growth",
|
|
article_id="article_0",
|
|
event_type=EventCategory.EARNINGS,
|
|
sentiment=0.7,
|
|
),
|
|
EntityMention(
|
|
company_name="Microsoft",
|
|
confidence=0.92,
|
|
context_snippet="Microsoft earnings exceeded analyst expectations",
|
|
article_id="article_1",
|
|
event_type=EventCategory.EARNINGS,
|
|
sentiment=0.8,
|
|
),
|
|
]
|
|
|
|
with patch("tradingagents.agents.discovery.scorer.resolve_ticker") as mock_resolve:
|
|
mock_resolve.return_value = "MSFT"
|
|
|
|
with patch("tradingagents.agents.discovery.scorer.classify_sector") as mock_sector:
|
|
mock_sector.return_value = "technology"
|
|
|
|
result = calculate_trending_scores(mentions, articles, min_mentions=2)
|
|
|
|
assert len(result) == 1
|
|
assert result[0].ticker == "MSFT"
|
|
assert result[0].mention_count == 2
|
|
assert result[0].sentiment > 0
|
|
|
|
|
|
class TestNewsVendorFailureGracefulDegradation:
|
|
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
|
def test_news_vendor_failure_with_graceful_degradation(self, mock_bulk_news):
|
|
mock_bulk_news.side_effect = NewsUnavailableError("All news vendors failed")
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
|
graph = TradingAgentsGraph()
|
|
graph.config = {
|
|
"discovery_timeout": 60,
|
|
"discovery_hard_timeout": 120,
|
|
"discovery_cache_ttl": 300,
|
|
"discovery_max_results": 20,
|
|
"discovery_min_mentions": 2,
|
|
}
|
|
|
|
result = graph.discover_trending()
|
|
|
|
assert result.status == DiscoveryStatus.FAILED
|
|
assert result.error_message is not None
|
|
assert "news" in result.error_message.lower() or "vendor" in result.error_message.lower()
|
|
|
|
|
|
class TestTimeoutHandlingWithPartialResults:
|
|
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
|
def test_timeout_handling_returns_error(self, mock_bulk_news):
|
|
def slow_fetch(*args, **kwargs):
|
|
import time
|
|
time.sleep(0.3)
|
|
return []
|
|
|
|
mock_bulk_news.side_effect = slow_fetch
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
|
graph = TradingAgentsGraph()
|
|
graph.config = {
|
|
"discovery_timeout": 60,
|
|
"discovery_hard_timeout": 0.1,
|
|
"discovery_cache_ttl": 300,
|
|
"discovery_max_results": 20,
|
|
"discovery_min_mentions": 2,
|
|
}
|
|
|
|
with pytest.raises(DiscoveryTimeoutError):
|
|
graph.discover_trending()
|
|
|
|
|
|
class TestNoTrendingStocksFound:
|
|
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
|
@patch("tradingagents.graph.trading_graph.extract_entities")
|
|
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
|
def test_no_trending_stocks_found_returns_empty_list(
|
|
self, mock_scores, mock_extract, mock_bulk_news
|
|
):
|
|
mock_bulk_news.return_value = [
|
|
NewsArticle(
|
|
title="General market update",
|
|
source="Reuters",
|
|
url="https://reuters.com/general",
|
|
published_at=datetime.now(),
|
|
content_snippet="Markets were quiet today with no major news.",
|
|
ticker_mentions=[],
|
|
),
|
|
]
|
|
mock_extract.return_value = []
|
|
mock_scores.return_value = []
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
|
graph = TradingAgentsGraph()
|
|
graph.config = {
|
|
"discovery_timeout": 60,
|
|
"discovery_hard_timeout": 120,
|
|
"discovery_cache_ttl": 300,
|
|
"discovery_max_results": 20,
|
|
"discovery_min_mentions": 2,
|
|
}
|
|
|
|
result = graph.discover_trending()
|
|
|
|
assert result.status == DiscoveryStatus.COMPLETED
|
|
assert len(result.trending_stocks) == 0
|
|
|
|
|
|
class TestAllStocksFilteredOutBySectorFilter:
|
|
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
|
@patch("tradingagents.graph.trading_graph.extract_entities")
|
|
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
|
def test_all_stocks_filtered_out_by_sector_filter(
|
|
self, mock_scores, mock_extract, mock_bulk_news
|
|
):
|
|
mock_bulk_news.return_value = []
|
|
mock_extract.return_value = []
|
|
mock_scores.return_value = [
|
|
TrendingStock(
|
|
ticker="AAPL",
|
|
company_name="Apple Inc.",
|
|
score=10.0,
|
|
mention_count=5,
|
|
sentiment=0.5,
|
|
sector=Sector.TECHNOLOGY,
|
|
event_type=EventCategory.EARNINGS,
|
|
news_summary="Apple earnings",
|
|
source_articles=[],
|
|
),
|
|
TrendingStock(
|
|
ticker="MSFT",
|
|
company_name="Microsoft",
|
|
score=9.0,
|
|
mention_count=4,
|
|
sentiment=0.4,
|
|
sector=Sector.TECHNOLOGY,
|
|
event_type=EventCategory.PRODUCT_LAUNCH,
|
|
news_summary="Microsoft product",
|
|
source_articles=[],
|
|
),
|
|
]
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
|
graph = TradingAgentsGraph()
|
|
graph.config = {
|
|
"discovery_timeout": 60,
|
|
"discovery_hard_timeout": 120,
|
|
"discovery_cache_ttl": 300,
|
|
"discovery_max_results": 20,
|
|
"discovery_min_mentions": 2,
|
|
}
|
|
|
|
request = DiscoveryRequest(
|
|
lookback_period="24h",
|
|
sector_filter=[Sector.HEALTHCARE],
|
|
)
|
|
result = graph.discover_trending(request)
|
|
|
|
assert result.status == DiscoveryStatus.COMPLETED
|
|
assert len(result.trending_stocks) == 0
|
|
|
|
|
|
class TestAllStocksFilteredOutByEventFilter:
|
|
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
|
@patch("tradingagents.graph.trading_graph.extract_entities")
|
|
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
|
def test_all_stocks_filtered_out_by_event_filter(
|
|
self, mock_scores, mock_extract, mock_bulk_news
|
|
):
|
|
mock_bulk_news.return_value = []
|
|
mock_extract.return_value = []
|
|
mock_scores.return_value = [
|
|
TrendingStock(
|
|
ticker="AAPL",
|
|
company_name="Apple Inc.",
|
|
score=10.0,
|
|
mention_count=5,
|
|
sentiment=0.5,
|
|
sector=Sector.TECHNOLOGY,
|
|
event_type=EventCategory.EARNINGS,
|
|
news_summary="Apple earnings",
|
|
source_articles=[],
|
|
),
|
|
]
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
|
graph = TradingAgentsGraph()
|
|
graph.config = {
|
|
"discovery_timeout": 60,
|
|
"discovery_hard_timeout": 120,
|
|
"discovery_cache_ttl": 300,
|
|
"discovery_max_results": 20,
|
|
"discovery_min_mentions": 2,
|
|
}
|
|
|
|
request = DiscoveryRequest(
|
|
lookback_period="24h",
|
|
event_filter=[EventCategory.MERGER_ACQUISITION],
|
|
)
|
|
result = graph.discover_trending(request)
|
|
|
|
assert result.status == DiscoveryStatus.COMPLETED
|
|
assert len(result.trending_stocks) == 0
|
|
|
|
|
|
class TestMultipleSectorsAndEventsFiltering:
|
|
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
|
@patch("tradingagents.graph.trading_graph.extract_entities")
|
|
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
|
def test_combined_sector_and_event_filtering(
|
|
self, mock_scores, mock_extract, mock_bulk_news
|
|
):
|
|
mock_bulk_news.return_value = []
|
|
mock_extract.return_value = []
|
|
mock_scores.return_value = [
|
|
TrendingStock(
|
|
ticker="AAPL",
|
|
company_name="Apple Inc.",
|
|
score=10.0,
|
|
mention_count=5,
|
|
sentiment=0.5,
|
|
sector=Sector.TECHNOLOGY,
|
|
event_type=EventCategory.EARNINGS,
|
|
news_summary="Apple earnings",
|
|
source_articles=[],
|
|
),
|
|
TrendingStock(
|
|
ticker="JPM",
|
|
company_name="JPMorgan Chase",
|
|
score=9.0,
|
|
mention_count=4,
|
|
sentiment=0.4,
|
|
sector=Sector.FINANCE,
|
|
event_type=EventCategory.EARNINGS,
|
|
news_summary="JPM earnings",
|
|
source_articles=[],
|
|
),
|
|
TrendingStock(
|
|
ticker="XOM",
|
|
company_name="Exxon Mobil",
|
|
score=8.0,
|
|
mention_count=3,
|
|
sentiment=0.3,
|
|
sector=Sector.ENERGY,
|
|
event_type=EventCategory.REGULATORY,
|
|
news_summary="XOM regulatory news",
|
|
source_articles=[],
|
|
),
|
|
]
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
|
graph = TradingAgentsGraph()
|
|
graph.config = {
|
|
"discovery_timeout": 60,
|
|
"discovery_hard_timeout": 120,
|
|
"discovery_cache_ttl": 300,
|
|
"discovery_max_results": 20,
|
|
"discovery_min_mentions": 2,
|
|
}
|
|
|
|
request = DiscoveryRequest(
|
|
lookback_period="24h",
|
|
sector_filter=[Sector.TECHNOLOGY, Sector.FINANCE],
|
|
event_filter=[EventCategory.EARNINGS],
|
|
)
|
|
result = graph.discover_trending(request)
|
|
|
|
assert result.status == DiscoveryStatus.COMPLETED
|
|
assert len(result.trending_stocks) == 2
|
|
tickers = [s.ticker for s in result.trending_stocks]
|
|
assert "AAPL" in tickers
|
|
assert "JPM" in tickers
|
|
assert "XOM" not in tickers
|
|
|
|
|
|
class TestDiscoveryResultPersistenceIntegration:
|
|
def test_discovery_result_can_be_serialized_and_saved(self):
|
|
from tradingagents.agents.discovery.persistence import (
|
|
save_discovery_result,
|
|
generate_markdown_summary,
|
|
)
|
|
import tempfile
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
article = NewsArticle(
|
|
title="Test article",
|
|
source="Test",
|
|
url="https://test.com",
|
|
published_at=datetime.now(),
|
|
content_snippet="Test content",
|
|
ticker_mentions=["TEST"],
|
|
)
|
|
|
|
stock = TrendingStock(
|
|
ticker="TEST",
|
|
company_name="Test Company",
|
|
score=5.0,
|
|
mention_count=2,
|
|
sentiment=0.5,
|
|
sector=Sector.TECHNOLOGY,
|
|
event_type=EventCategory.OTHER,
|
|
news_summary="Test news summary",
|
|
source_articles=[article],
|
|
)
|
|
|
|
request = DiscoveryRequest(
|
|
lookback_period="24h",
|
|
created_at=datetime.now(),
|
|
)
|
|
|
|
result = DiscoveryResult(
|
|
request=request,
|
|
trending_stocks=[stock],
|
|
status=DiscoveryStatus.COMPLETED,
|
|
started_at=datetime.now(),
|
|
completed_at=datetime.now(),
|
|
)
|
|
|
|
temp_dir = tempfile.mkdtemp()
|
|
try:
|
|
path = save_discovery_result(result, base_path=Path(temp_dir))
|
|
assert path.exists()
|
|
assert (path / "discovery_result.json").exists()
|
|
assert (path / "discovery_summary.md").exists()
|
|
|
|
markdown = generate_markdown_summary(result)
|
|
assert "TEST" in markdown
|
|
assert "Test Company" in markdown
|
|
finally:
|
|
shutil.rmtree(temp_dir)
|