349 lines
12 KiB
Python
349 lines
12 KiB
Python
import pytest
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
from datetime import datetime, timedelta
|
|
import requests
|
|
from tradingagents.dataflows.brave import (
|
|
get_api_key,
|
|
get_bulk_news_brave,
|
|
_parse_brave_age,
|
|
_make_request_with_retry,
|
|
BRAVE_SEARCH_URL,
|
|
DEFAULT_TIMEOUT,
|
|
MAX_RETRIES,
|
|
)
|
|
|
|
|
|
class TestGetApiKey:
    """Tests for reading the Brave API key from the BRAVE_API_KEY environment variable."""

    def test_get_api_key_success(self):
        """The configured key is returned unchanged when the variable is present."""
        fake_env = {'BRAVE_API_KEY': 'test_key_123'}
        with patch.dict('os.environ', fake_env):
            api_key = get_api_key()
        assert api_key == 'test_key_123'

    def test_get_api_key_missing(self):
        """With an emptied environment, a ValueError naming the variable is raised."""
        expected_message = "BRAVE_API_KEY environment variable is not set"
        with patch.dict('os.environ', {}, clear=True), pytest.raises(ValueError, match=expected_message):
            get_api_key()
|
|
|
|
|
|
class TestParseBraveAge:
    """Tests for turning Brave's relative "age" strings into naive datetimes."""

    # Maximum drift (in seconds) allowed between the parsed timestamp and the
    # reference timestamp, which is computed a moment after parsing.
    _TOLERANCE_SECONDS = 2

    def _assert_age(self, age_text, offset):
        """Parse *age_text* and check it lies *offset* before now, within tolerance."""
        parsed = _parse_brave_age(age_text)
        reference = datetime.now() - offset
        drift = abs((parsed - reference).total_seconds())
        assert drift < self._TOLERANCE_SECONDS

    def test_parse_hours_ago(self):
        self._assert_age("2 hours ago", timedelta(hours=2))

    def test_parse_single_hour(self):
        self._assert_age("1 hour ago", timedelta(hours=1))

    def test_parse_days_ago(self):
        self._assert_age("3 days ago", timedelta(days=3))

    def test_parse_weeks_ago(self):
        self._assert_age("2 weeks ago", timedelta(weeks=2))

    def test_parse_minutes_ago(self):
        self._assert_age("30 minutes ago", timedelta(minutes=30))

    def test_parse_empty_string(self):
        # An empty age falls back to "now".
        self._assert_age("", timedelta())

    def test_parse_invalid_format(self):
        # Unrecognised text also falls back to "now".
        self._assert_age("invalid format", timedelta())

    def test_parse_uppercase(self):
        # Unit matching is expected to be case-insensitive.
        self._assert_age("5 HOURS AGO", timedelta(hours=5))
|
|
|
|
|
|
class TestMakeRequestWithRetry:
    """Tests for ``_make_request_with_retry``.

    ``requests.get`` is patched so no network traffic occurs. ``time.sleep`` is
    patched wherever a retry, or an accidental retry caused by a regression,
    could otherwise make the test wait for real.
    """

    @patch('tradingagents.dataflows.brave.requests.get')
    def test_successful_request(self, mock_get):
        """A successful first attempt is returned as-is, without retrying."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response

        result = _make_request_with_retry("http://test.com", {}, {})

        assert result is mock_response
        mock_get.assert_called_once()

    @patch('tradingagents.dataflows.brave.requests.get')
    @patch('tradingagents.dataflows.brave.time.sleep')
    def test_retry_on_timeout(self, mock_sleep, mock_get):
        """Two timeouts are retried, with one sleep per failure, then the third attempt succeeds."""
        success_response = Mock(status_code=200, raise_for_status=Mock())
        mock_get.side_effect = [
            requests.exceptions.Timeout(),
            requests.exceptions.Timeout(),
            success_response,
        ]

        result = _make_request_with_retry("http://test.com", {}, {})

        # The caller must receive the response from the attempt that succeeded.
        assert result is success_response
        assert mock_get.call_count == 3
        assert mock_sleep.call_count == 2

    @patch('tradingagents.dataflows.brave.requests.get')
    @patch('tradingagents.dataflows.brave.time.sleep')
    def test_retry_on_connection_error(self, mock_sleep, mock_get):
        """A connection error is retried once, and the next successful response is returned."""
        success_response = Mock(status_code=200, raise_for_status=Mock())
        mock_get.side_effect = [
            requests.exceptions.ConnectionError(),
            success_response,
        ]

        result = _make_request_with_retry("http://test.com", {}, {})

        assert result is success_response
        assert mock_get.call_count == 2
        assert mock_sleep.call_count == 1

    @patch('tradingagents.dataflows.brave.requests.get')
    @patch('tradingagents.dataflows.brave.time.sleep')
    def test_retry_on_rate_limit(self, mock_sleep, mock_get):
        """A 429 response with Retry-After triggers one sleep and a retry."""
        rate_limited_response = Mock()
        rate_limited_response.status_code = 429
        rate_limited_response.headers = {"Retry-After": "1"}

        success_response = Mock()
        success_response.status_code = 200
        success_response.raise_for_status = Mock()

        mock_get.side_effect = [rate_limited_response, success_response]

        result = _make_request_with_retry("http://test.com", {}, {})

        # The rate-limited response must not leak out to the caller.
        assert result is success_response
        assert mock_get.call_count == 2
        assert mock_sleep.call_count == 1

    @patch('tradingagents.dataflows.brave.requests.get')
    @patch('tradingagents.dataflows.brave.time.sleep')
    def test_max_retries_exceeded(self, mock_sleep, mock_get):
        """With every attempt timing out, the last Timeout propagates after max_retries calls."""
        mock_get.side_effect = requests.exceptions.Timeout()

        with pytest.raises(requests.exceptions.Timeout):
            _make_request_with_retry("http://test.com", {}, {}, max_retries=3)

        assert mock_get.call_count == 3

    @patch('tradingagents.dataflows.brave.requests.get')
    @patch('tradingagents.dataflows.brave.time.sleep')
    def test_non_retryable_http_error(self, mock_sleep, mock_get):
        """A 400 response raises HTTPError immediately: one attempt and no back-off."""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=mock_response)
        mock_get.return_value = mock_response

        with pytest.raises(requests.exceptions.HTTPError):
            _make_request_with_retry("http://test.com", {}, {})

        assert mock_get.call_count == 1
        # sleep is patched so that a regression which starts retrying 4xx
        # responses fails fast instead of really waiting.
        mock_sleep.assert_not_called()
|
|
|
|
|
|
class TestGetBulkNewsBrave:
    """Tests for ``get_bulk_news_brave``.

    ``get_api_key`` and ``_make_request_with_retry`` are patched, so each test
    controls the JSON payload that every Brave query receives.
    """

    @staticmethod
    def _stub_results(mock_get_api_key, mock_request, results):
        """Supply a fake API key and make every query return ``{"results": results}``."""
        mock_get_api_key.return_value = "test_key"
        mock_response = Mock()
        mock_response.json.return_value = {"results": results}
        mock_request.return_value = mock_response
        return mock_response

    @staticmethod
    def _first_query_params(mock_request):
        """Return the params dict passed as the third positional argument of the first request."""
        return mock_request.call_args_list[0][0][2]

    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_returns_empty_when_no_api_key(self, mock_get_api_key):
        """A missing API key yields an empty list instead of an exception."""
        mock_get_api_key.side_effect = ValueError("BRAVE_API_KEY not set")

        result = get_bulk_news_brave(24)

        assert result == []

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_basic_call(self, mock_get_api_key, mock_request):
        """With empty payloads, a list is still returned after all five queries are issued."""
        self._stub_results(mock_get_api_key, mock_request, [])

        result = get_bulk_news_brave(24)

        assert isinstance(result, list)
        assert mock_request.call_count == 5

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_parses_articles(self, mock_get_api_key, mock_request):
        """A raw Brave result is mapped to the article dict schema."""
        mock_article = {
            "title": "Test Stock News",
            "meta_url": {"netloc": "reuters.com"},
            "url": "https://reuters.com/article1",
            "age": "2 hours ago",
            "description": "This is a test article about stocks.",
        }
        self._stub_results(mock_get_api_key, mock_request, [mock_article])

        result = get_bulk_news_brave(24)

        assert len(result) >= 1
        article = result[0]
        assert article["title"] == "Test Stock News"
        assert article["source"] == "reuters.com"
        assert article["url"] == "https://reuters.com/article1"
        assert "published_at" in article
        assert article["content_snippet"] == "This is a test article about stocks."

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_deduplicates_by_url(self, mock_get_api_key, mock_request):
        """The same URL, repeated within and across queries, appears exactly once."""
        duplicate_article = {
            "title": "Duplicate Article",
            "meta_url": {"netloc": "news.com"},
            "url": "https://news.com/same-url",
            "age": "1 hour ago",
            "description": "Duplicate content.",
        }
        self._stub_results(mock_get_api_key, mock_request, [duplicate_article, duplicate_article])

        result = get_bulk_news_brave(24)

        # Check the exact contents: a uniqueness check alone would also pass
        # on an empty result.
        assert [a["url"] for a in result] == ["https://news.com/same-url"]

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_truncates_long_descriptions(self, mock_get_api_key, mock_request):
        """Descriptions longer than 500 characters are cut to a 500-character snippet."""
        long_description = "A" * 1000
        mock_article = {
            "title": "Long Article",
            "meta_url": {"netloc": "news.com"},
            "url": "https://news.com/article",
            "age": "1 hour ago",
            "description": long_description,
        }
        self._stub_results(mock_get_api_key, mock_request, [mock_article])

        result = get_bulk_news_brave(24)

        assert len(result[0]["content_snippet"]) == 500

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_freshness_parameter_24h(self, mock_get_api_key, mock_request):
        """A 24-hour lookback maps to Brave's past-day freshness code "pd"."""
        self._stub_results(mock_get_api_key, mock_request, [])

        get_bulk_news_brave(24)

        assert self._first_query_params(mock_request)["freshness"] == "pd"

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_freshness_parameter_7d(self, mock_get_api_key, mock_request):
        """A 168-hour (7-day) lookback maps to the past-week freshness code "pw"."""
        self._stub_results(mock_get_api_key, mock_request, [])

        get_bulk_news_brave(168)

        assert self._first_query_params(mock_request)["freshness"] == "pw"

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_freshness_parameter_month(self, mock_get_api_key, mock_request):
        """A 720-hour (30-day) lookback maps to the past-month freshness code "pm"."""
        self._stub_results(mock_get_api_key, mock_request, [])

        get_bulk_news_brave(720)

        assert self._first_query_params(mock_request)["freshness"] == "pm"

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_handles_missing_meta_url(self, mock_get_api_key, mock_request):
        """Without meta_url, the source falls back to "Brave News"."""
        mock_article = {
            "title": "Article Without Meta URL",
            "url": "https://news.com/article",
            "age": "1 hour ago",
            "description": "Content",
        }
        self._stub_results(mock_get_api_key, mock_request, [mock_article])

        result = get_bulk_news_brave(24)

        assert result[0]["source"] == "Brave News"

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_continues_on_query_failure(self, mock_get_api_key, mock_request):
        """A failing query is skipped, and the remaining queries still contribute results."""
        ok_response = self._stub_results(
            mock_get_api_key,
            mock_request,
            [{"title": "Article", "url": "https://test.com", "age": "1h", "description": "test"}],
        )
        # An iterable side_effect takes precedence over return_value: the first
        # query raises, and the next four return the stubbed payload.
        mock_request.side_effect = [requests.exceptions.HTTPError("Error")] + [ok_response] * 4

        result = get_bulk_news_brave(24)

        assert len(result) > 0
        # Every query was attempted despite the first failure.
        assert mock_request.call_count == 5

    @patch('tradingagents.dataflows.brave._make_request_with_retry')
    @patch('tradingagents.dataflows.brave.get_api_key')
    def test_skips_articles_without_url(self, mock_get_api_key, mock_request):
        """Results lacking a URL are dropped, while URL-bearing ones are kept."""
        mock_articles = [
            {"title": "No URL Article", "age": "1h", "description": "test"},
            {"title": "Has URL", "url": "https://test.com", "age": "1h", "description": "test"},
        ]
        self._stub_results(mock_get_api_key, mock_request, mock_articles)

        result = get_bulk_news_brave(24)

        # Check the unfiltered result. Pre-filtering falsy URLs, as before,
        # made this assertion vacuous.
        assert all(article.get("url") for article in result)
        assert "No URL Article" not in [article["title"] for article in result]
        assert [article["url"] for article in result] == ["https://test.com"]
|