248 lines
8.7 KiB
Python
248 lines
8.7 KiB
Python
import pytest
|
|
from unittest.mock import Mock, patch
|
|
from datetime import datetime, timedelta
|
|
from tradingagents.dataflows.google import (
|
|
get_google_news,
|
|
get_bulk_news_google,
|
|
)
|
|
|
|
|
|
class TestGetGoogleNews:
|
|
"""Test suite for get_google_news function."""
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_google_news_basic(self, mock_get_news_data):
|
|
"""Test basic Google News retrieval."""
|
|
mock_get_news_data.return_value = []
|
|
|
|
query = "AAPL stock"
|
|
curr_date = "2024-01-15"
|
|
look_back_days = 7
|
|
|
|
result = get_google_news(query, curr_date, look_back_days)
|
|
|
|
assert isinstance(result, str)
|
|
mock_get_news_data.assert_called_once()
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_google_news_query_formatting(self, mock_get_news_data):
|
|
"""Test that query spaces are replaced with plus signs."""
|
|
mock_get_news_data.return_value = []
|
|
|
|
query = "Apple Inc stock news"
|
|
curr_date = "2024-01-15"
|
|
look_back_days = 7
|
|
|
|
result = get_google_news(query, curr_date, look_back_days)
|
|
|
|
# Query should be formatted with + instead of spaces
|
|
call_args = mock_get_news_data.call_args[0]
|
|
assert "+" in call_args[0] or call_args[0] == query.replace(" ", "+")
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_google_news_with_results(self, mock_get_news_data):
|
|
"""Test formatting of news results."""
|
|
mock_news = [
|
|
{
|
|
"title": "Apple stock rises",
|
|
"source": "Bloomberg",
|
|
"snippet": "Apple Inc. shares rose 5% today...",
|
|
},
|
|
{
|
|
"title": "New iPhone release",
|
|
"source": "Reuters",
|
|
"snippet": "Apple announces new iPhone model...",
|
|
},
|
|
]
|
|
|
|
mock_get_news_data.return_value = mock_news
|
|
|
|
query = "AAPL"
|
|
curr_date = "2024-01-15"
|
|
look_back_days = 7
|
|
|
|
result = get_google_news(query, curr_date, look_back_days)
|
|
|
|
assert "Apple stock rises" in result
|
|
assert "New iPhone release" in result
|
|
assert "Bloomberg" in result
|
|
assert "Reuters" in result
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_google_news_empty_results(self, mock_get_news_data):
|
|
"""Test handling of empty news results."""
|
|
mock_get_news_data.return_value = []
|
|
|
|
query = "NonexistentTicker"
|
|
curr_date = "2024-01-15"
|
|
look_back_days = 7
|
|
|
|
result = get_google_news(query, curr_date, look_back_days)
|
|
|
|
assert result == ""
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_google_news_date_calculation(self, mock_get_news_data):
|
|
"""Test that lookback date is calculated correctly."""
|
|
mock_get_news_data.return_value = []
|
|
|
|
query = "TSLA"
|
|
curr_date = "2024-01-15"
|
|
look_back_days = 30
|
|
|
|
result = get_google_news(query, curr_date, look_back_days)
|
|
|
|
# Verify date calculation by checking call arguments
|
|
call_args = mock_get_news_data.call_args[0]
|
|
before_date = call_args[1]
|
|
end_date = call_args[2]
|
|
|
|
assert end_date == curr_date
|
|
|
|
|
|
class TestGetBulkNewsGoogle:
|
|
"""Test suite for get_bulk_news_google function."""
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_basic(self, mock_get_news_data):
|
|
"""Test basic bulk news retrieval."""
|
|
mock_get_news_data.return_value = []
|
|
|
|
result = get_bulk_news_google(24)
|
|
|
|
assert isinstance(result, list)
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_multiple_queries(self, mock_get_news_data):
|
|
"""Test that multiple search queries are executed."""
|
|
mock_get_news_data.return_value = []
|
|
|
|
result = get_bulk_news_google(24)
|
|
|
|
# Should call getNewsData multiple times for different queries
|
|
assert mock_get_news_data.call_count >= 3
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_with_articles(self, mock_get_news_data):
|
|
"""Test article parsing and deduplication."""
|
|
mock_articles = [
|
|
{
|
|
"title": "Market update",
|
|
"source": "Financial Times",
|
|
"snippet": "Markets closed higher today...",
|
|
"link": "https://example.com/1",
|
|
"date": "2024-01-15",
|
|
},
|
|
{
|
|
"title": "Trading news",
|
|
"source": "WSJ",
|
|
"snippet": "Trading volume increased...",
|
|
"link": "https://example.com/2",
|
|
"date": "2024-01-15",
|
|
},
|
|
]
|
|
|
|
mock_get_news_data.return_value = mock_articles
|
|
|
|
result = get_bulk_news_google(24)
|
|
|
|
assert len(result) > 0
|
|
assert all("title" in article for article in result)
|
|
assert all("source" in article for article in result)
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_deduplication(self, mock_get_news_data):
|
|
"""Test that duplicate articles are removed."""
|
|
duplicate_article = {
|
|
"title": "Same article",
|
|
"source": "Source",
|
|
"snippet": "Content",
|
|
"link": "https://example.com",
|
|
"date": "2024-01-15",
|
|
}
|
|
|
|
# Return same article multiple times
|
|
mock_get_news_data.return_value = [duplicate_article, duplicate_article]
|
|
|
|
result = get_bulk_news_google(24)
|
|
|
|
# Should only appear once
|
|
titles = [article["title"] for article in result]
|
|
assert titles.count("Same article") <= 1
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_content_truncation(self, mock_get_news_data):
|
|
"""Test that content snippets are truncated to 500 characters."""
|
|
long_snippet = "A" * 1000
|
|
|
|
mock_articles = [
|
|
{
|
|
"title": "Article",
|
|
"source": "Source",
|
|
"snippet": long_snippet,
|
|
"link": "https://example.com",
|
|
"date": "2024-01-15",
|
|
}
|
|
]
|
|
|
|
mock_get_news_data.return_value = mock_articles
|
|
|
|
result = get_bulk_news_google(24)
|
|
|
|
if len(result) > 0:
|
|
assert len(result[0]["content_snippet"]) <= 500
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_error_handling(self, mock_get_news_data):
|
|
"""Test error handling when getNewsData raises exception."""
|
|
mock_get_news_data.side_effect = Exception("API Error")
|
|
|
|
result = get_bulk_news_google(24)
|
|
|
|
# Should return empty list or partial results
|
|
assert isinstance(result, list)
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_lookback_periods(self, mock_get_news_data):
|
|
"""Test with various lookback periods."""
|
|
mock_get_news_data.return_value = []
|
|
|
|
lookback_hours = [1, 6, 12, 24, 48, 168]
|
|
|
|
for hours in lookback_hours:
|
|
result = get_bulk_news_google(hours)
|
|
assert isinstance(result, list)
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_date_formatting(self, mock_get_news_data):
|
|
"""Test that dates are formatted correctly for API."""
|
|
mock_get_news_data.return_value = []
|
|
|
|
result = get_bulk_news_google(24)
|
|
|
|
# Check that dates in YYYY-MM-DD format are used
|
|
for call in mock_get_news_data.call_args_list:
|
|
start_date = call[0][1]
|
|
end_date = call[0][2]
|
|
|
|
# Both should be in YYYY-MM-DD format
|
|
assert len(start_date) == 10
|
|
assert len(end_date) == 10
|
|
assert start_date.count("-") == 2
|
|
assert end_date.count("-") == 2
|
|
|
|
@patch('tradingagents.dataflows.google.getNewsData')
|
|
def test_get_bulk_news_google_missing_fields(self, mock_get_news_data):
|
|
"""Test handling of articles with missing fields."""
|
|
incomplete_articles = [
|
|
{"title": "Title only"},
|
|
{"source": "Source only"},
|
|
{"title": "Complete", "source": "Source", "snippet": "Text", "link": "url", "date": "2024-01-15"},
|
|
]
|
|
|
|
mock_get_news_data.return_value = incomplete_articles
|
|
|
|
result = get_bulk_news_google(24)
|
|
|
|
# Should handle missing fields gracefully
|
|
assert isinstance(result, list) |