204 lines
6.9 KiB
Python
204 lines
6.9 KiB
Python
from datetime import datetime
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
|
|
from tradingagents.agents.discovery import (
|
|
DiscoveryRequest,
|
|
DiscoveryResult,
|
|
DiscoveryStatus,
|
|
DiscoveryTimeoutError,
|
|
EventCategory,
|
|
NewsArticle,
|
|
Sector,
|
|
TrendingStock,
|
|
)
|
|
|
|
|
|
def create_mock_trending_stock(
    ticker: str = "AAPL",
    company_name: str = "Apple Inc.",
    score: float = 10.0,
    sector: Sector = Sector.TECHNOLOGY,
    event_type: EventCategory = EventCategory.EARNINGS,
) -> TrendingStock:
    """Build a ``TrendingStock`` fixture for the discovery tests.

    Args:
        ticker: Ticker symbol stored on the fixture.
        company_name: Company name stored on the fixture.
        score: Score stored on the fixture.
        sector: Sector stored on the fixture.
        event_type: Event category stored on the fixture.

    Returns:
        A ``TrendingStock`` carrying the arguments above plus fixed values
        for the fields the tests never vary (5 mentions, 0.5 sentiment,
        a placeholder summary and no source articles).
    """
    # Fields no test customises; rebuilt on every call so the empty
    # ``source_articles`` list is never shared between fixtures.
    fixed_fields = {
        "mention_count": 5,
        "sentiment": 0.5,
        "news_summary": "Test news summary",
        "source_articles": [],
    }
    return TrendingStock(
        ticker=ticker,
        company_name=company_name,
        score=score,
        sector=sector,
        event_type=event_type,
        **fixed_fields,
    )
|
|
|
|
|
|
def create_mock_news_article() -> NewsArticle:
    """Build a ``NewsArticle`` fixture that mentions AAPL.

    Returns:
        A ``NewsArticle`` with placeholder title, source, URL and snippet,
        a single ``"AAPL"`` ticker mention, and ``published_at`` set to the
        local time at which this function is called.
    """
    # Stamped at call time rather than with a fixed date.
    published = datetime.now()
    mentioned_tickers = ["AAPL"]
    return NewsArticle(
        title="Test Article",
        source="Test Source",
        url="https://example.com/article",
        published_at=published,
        content_snippet="Test content about Apple stock",
        ticker_mentions=mentioned_tickers,
    )
|
|
|
|
|
|
class TestDiscoverTrendingReturnsDiscoveryResult:
    """Checks the object returned by ``discover_trending`` with no request."""

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    def test_discover_trending_returns_discovery_result(
        self, mock_scores, mock_extract, mock_bulk_news
    ):
        """A default call returns a completed, non-empty ``DiscoveryResult``.

        The three patched ``trading_graph`` functions are stubbed with canned
        values: one article, no extracted entities, one scored stock.
        """
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        mock_bulk_news.return_value = [create_mock_news_article()]
        mock_extract.return_value = []
        mock_scores.return_value = [create_mock_trending_stock()]

        discovery_config = {
            "discovery_timeout": 60,
            "discovery_hard_timeout": 120,
            "discovery_cache_ttl": 300,
            "discovery_max_results": 20,
            "discovery_min_mentions": 2,
        }

        # Replace the constructor with a no-op so a bare instance can be
        # created; only ``config`` is populated on it afterwards.
        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.config = discovery_config

            result = graph.discover_trending()

            assert isinstance(result, DiscoveryResult)
            assert result.status == DiscoveryStatus.COMPLETED
            assert len(result.trending_stocks) > 0
|
|
|
|
|
|
class TestAnalyzeTrendingCallsPropagate:
    """Checks that ``analyze_trending`` delegates to ``propagate``."""

    def test_analyze_trending_calls_propagate(self):
        """``propagate`` is called exactly once, with the ticker first.

        ``propagate`` is replaced by a ``Mock`` on a graph built without
        running the real constructor, so only the delegation is exercised.
        """
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # Replace the constructor with a no-op so a bare instance can be
        # created without whatever the real ``__init__`` sets up.
        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.propagate = Mock(return_value=({"final_state": "test"}, "BUY"))

            trending_stock = create_mock_trending_stock()

            # The return value is not inspected by this test, so it is not
            # bound to a name (the original assigned it to an unused local).
            graph.analyze_trending(trending_stock)

            graph.propagate.assert_called_once()
            # ``call_args[0]`` is the tuple of positional arguments.
            call_args = graph.propagate.call_args
            assert call_args[0][0] == "AAPL"
|
|
|
|
|
|
class TestSectorFilterParameter:
    """Checks the ``sector_filter`` field of ``DiscoveryRequest``."""

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    def test_sector_filter_filters_results(
        self, mock_scores, mock_extract, mock_bulk_news
    ):
        """Only technology stocks remain when filtering on that sector.

        The scoring stub returns one stock each for the technology, finance
        and energy sectors; the request asks for technology only.
        """
        mock_bulk_news.return_value = [create_mock_news_article()]
        mock_extract.return_value = []
        mock_scores.return_value = [
            create_mock_trending_stock(ticker="AAPL", sector=Sector.TECHNOLOGY),
            create_mock_trending_stock(ticker="JPM", sector=Sector.FINANCE),
            create_mock_trending_stock(ticker="XOM", sector=Sector.ENERGY),
        ]

        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # Replace the constructor with a no-op so a bare instance can be
        # created; only ``config`` is populated on it afterwards.
        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.config = {
                "discovery_timeout": 60,
                "discovery_hard_timeout": 120,
                "discovery_cache_ttl": 300,
                "discovery_max_results": 20,
                "discovery_min_mentions": 2,
            }

            request = DiscoveryRequest(
                lookback_period="24h",
                sector_filter=[Sector.TECHNOLOGY],
            )
            result = graph.discover_trending(request)

            # ``all()`` is True for an empty iterable, so without this guard
            # the test would also pass if the filter discarded every stock.
            assert result.trending_stocks, "sector filter removed every stock"
            assert all(
                stock.sector == Sector.TECHNOLOGY for stock in result.trending_stocks
            )
|
|
|
|
|
|
class TestEventFilterParameter:
    """Checks the ``event_filter`` field of ``DiscoveryRequest``."""

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    def test_event_filter_filters_results(
        self, mock_scores, mock_extract, mock_bulk_news
    ):
        """Only earnings stocks remain when filtering on that event type.

        The scoring stub returns one stock each for the earnings, product
        launch and merger/acquisition categories; the request asks for
        earnings only.
        """
        mock_bulk_news.return_value = [create_mock_news_article()]
        mock_extract.return_value = []
        mock_scores.return_value = [
            create_mock_trending_stock(
                ticker="AAPL", event_type=EventCategory.EARNINGS
            ),
            create_mock_trending_stock(
                ticker="MSFT", event_type=EventCategory.PRODUCT_LAUNCH
            ),
            create_mock_trending_stock(
                ticker="GOOGL", event_type=EventCategory.MERGER_ACQUISITION
            ),
        ]

        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # Replace the constructor with a no-op so a bare instance can be
        # created; only ``config`` is populated on it afterwards.
        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.config = {
                "discovery_timeout": 60,
                "discovery_hard_timeout": 120,
                "discovery_cache_ttl": 300,
                "discovery_max_results": 20,
                "discovery_min_mentions": 2,
            }

            request = DiscoveryRequest(
                lookback_period="24h",
                event_filter=[EventCategory.EARNINGS],
            )
            result = graph.discover_trending(request)

            # ``all()`` is True for an empty iterable, so without this guard
            # the test would also pass if the filter discarded every stock.
            assert result.trending_stocks, "event filter removed every stock"
            assert all(
                stock.event_type == EventCategory.EARNINGS
                for stock in result.trending_stocks
            )
|
|
|
|
|
|
class TestTimeoutHandling:
    """Checks that a slow news fetch surfaces as ``DiscoveryTimeoutError``."""

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    def test_timeout_raises_discovery_timeout_error(self, mock_bulk_news):
        """``discover_trending`` raises when the fetch stub stalls.

        The stub sleeps for 0.5 s, which is longer than the 0.1 configured
        under ``discovery_hard_timeout``.
        """
        import time

        from tradingagents.graph.trading_graph import TradingAgentsGraph

        def stalled_fetch(*args, **kwargs):
            # Block past the configured 0.1 before handing back no articles.
            time.sleep(0.5)
            return []

        mock_bulk_news.side_effect = stalled_fetch

        tight_timeout_config = {
            "discovery_timeout": 60,
            "discovery_hard_timeout": 0.1,
            "discovery_cache_ttl": 300,
            "discovery_max_results": 20,
            "discovery_min_mentions": 2,
        }

        # Replace the constructor with a no-op so a bare instance can be
        # created; only ``config`` is populated on it afterwards.
        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.config = tight_timeout_config

            with pytest.raises(DiscoveryTimeoutError):
                graph.discover_trending()
|