From 24e955d29b4da7eaa381d0e3d913d6b1b6868dae Mon Sep 17 00:00:00 2001
From: Joseph O'Brien <98370624+89jobrien@users.noreply.github.com>
Date: Wed, 3 Dec 2025 00:43:27 -0500
Subject: [PATCH] Add Brave and Tavily news data vendors with retry handling
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add Brave Search API integration for bulk news fetching
- Add Tavily API integration for bulk news fetching
- Implement timeout and retry logic with exponential backoff for both vendors
- Make bulk news vendor order configurable via default_config.py
- Add tavily-python to requirements.txt
- Document BRAVE_API_KEY and TAVILY_API_KEY in .env.example and README
- Add comprehensive unit tests for both vendors (49 tests)

The news discovery system now uses a fallback chain: Tavily → Brave → Alpha Vantage → OpenAI → Google

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude
---
 .env.example                         |   4 +-
 README.md                            |  17 +-
 requirements.txt                     |   1 +
 tests/dataflows/test_brave.py        | 348 +++++++++++++++++++++++++
 tests/dataflows/test_tavily.py       | 370 +++++++++++++++++++++++++++
 tradingagents/dataflows/brave.py     | 150 +++++++++++
 tradingagents/dataflows/interface.py |   7 +-
 tradingagents/dataflows/tavily.py    | 128 +++++++++
 tradingagents/default_config.py      |   3 +
 9 files changed, 1024 insertions(+), 4 deletions(-)
 create mode 100644 tests/dataflows/test_brave.py
 create mode 100644 tests/dataflows/test_tavily.py
 create mode 100644 tradingagents/dataflows/brave.py
 create mode 100644 tradingagents/dataflows/tavily.py

diff --git a/.env.example b/.env.example
index 1e257c3c..2fb8acc8 100644
--- a/.env.example
+++ b/.env.example
@@ -1,2 +1,4 @@
 ALPHA_VANTAGE_API_KEY=alpha_vantage_api_key_placeholder
-OPENAI_API_KEY=openai_api_key_placeholder
\ No newline at end of file
+OPENAI_API_KEY=openai_api_key_placeholder
+BRAVE_API_KEY=brave_api_key_placeholder
+TAVILY_API_KEY=tavily_api_key_placeholder
\ No newline at end of file
diff --git a/README.md b/README.md
index e197ae31..23492cff 100644
--- a/README.md
+++ b/README.md
@@ -73,13 +73,25 @@ source .venv/bin/activate
 
 ### Required API Keys
 
-The framework requires an OpenAI API key for powering the agents and an Alpha Vantage API key for fundamental and news data (default configuration).
+The framework requires an OpenAI API key for powering the agents and at least one news data provider API key.
+
+**Required:**
+- `OPENAI_API_KEY` - Powers the LLM agents
+
+**News Data Providers (at least one required):**
+- `TAVILY_API_KEY` - Tavily search API (preferred for news discovery)
+- `BRAVE_API_KEY` - Brave Search API (fallback option)
+- `ALPHA_VANTAGE_API_KEY` - Alpha Vantage API (for fundamentals and news)
+
+The news discovery system uses a fallback chain: Tavily → Brave → Alpha Vantage → OpenAI → Google. Configure the API keys for your preferred providers.
 
 Set environment variables:
 
 ```bash
 export OPENAI_API_KEY=your_openai_api_key
 export ALPHA_VANTAGE_API_KEY=your_alpha_vantage_api_key
+export TAVILY_API_KEY=your_tavily_api_key
+export BRAVE_API_KEY=your_brave_api_key
 ```
 
 Alternatively, create a `.env` file in the project root:
@@ -95,7 +107,7 @@ Then edit the `.env` file with your API keys.
 Run the CLI:
 
 ```bash
-uv run cli/main.py
+uv run python -m cli.main
 ```
 
 The CLI provides two main modes:
@@ -105,6 +117,7 @@ The CLI provides two main modes:
 The trending stock discovery feature uses LLM-powered entity extraction to identify stocks making news. 
This is the primary enhancement in this fork, enabling proactive discovery of trading opportunities. Configuration options: + - **Lookback period:** 1h, 6h, 24h, or 7d - **Sector filter:** Technology, Healthcare, Finance, Energy, Consumer Goods, Industrials - **Event type filter:** Earnings, Merger/Acquisition, Regulatory, Product Launch, Executive Change diff --git a/requirements.txt b/requirements.txt index a6154cd2..c828690d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,4 @@ rich questionary langchain_anthropic langchain-google-genai +tavily-python diff --git a/tests/dataflows/test_brave.py b/tests/dataflows/test_brave.py new file mode 100644 index 00000000..e257301f --- /dev/null +++ b/tests/dataflows/test_brave.py @@ -0,0 +1,348 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime, timedelta +import requests +from tradingagents.dataflows.brave import ( + get_api_key, + get_bulk_news_brave, + _parse_brave_age, + _make_request_with_retry, + BRAVE_SEARCH_URL, + DEFAULT_TIMEOUT, + MAX_RETRIES, +) + + +class TestGetApiKey: + + def test_get_api_key_success(self): + with patch.dict('os.environ', {'BRAVE_API_KEY': 'test_key_123'}): + result = get_api_key() + assert result == 'test_key_123' + + def test_get_api_key_missing(self): + with patch.dict('os.environ', {}, clear=True): + with pytest.raises(ValueError, match="BRAVE_API_KEY environment variable is not set"): + get_api_key() + + +class TestParseBraveAge: + + def test_parse_hours_ago(self): + result = _parse_brave_age("2 hours ago") + expected = datetime.now() - timedelta(hours=2) + assert abs((result - expected).total_seconds()) < 2 + + def test_parse_single_hour(self): + result = _parse_brave_age("1 hour ago") + expected = datetime.now() - timedelta(hours=1) + assert abs((result - expected).total_seconds()) < 2 + + def test_parse_days_ago(self): + result = _parse_brave_age("3 days ago") + expected = datetime.now() - timedelta(days=3) + assert abs((result - expected).total_seconds()) < 2 + + def test_parse_weeks_ago(self): + result = _parse_brave_age("2 weeks ago") + expected = datetime.now() - timedelta(weeks=2) + assert abs((result - expected).total_seconds()) < 2 + + def test_parse_minutes_ago(self): + result = _parse_brave_age("30 minutes ago") + expected = datetime.now() - timedelta(minutes=30) + assert abs((result - expected).total_seconds()) < 2 + + def test_parse_empty_string(self): + result = _parse_brave_age("") + expected = datetime.now() + assert abs((result - expected).total_seconds()) < 2 + + def test_parse_invalid_format(self): + result = _parse_brave_age("invalid format") + expected = datetime.now() + assert abs((result - expected).total_seconds()) < 2 + + def test_parse_uppercase(self): + result = _parse_brave_age("5 HOURS AGO") + expected = datetime.now() - timedelta(hours=5) + assert abs((result - expected).total_seconds()) < 2 + + +class TestMakeRequestWithRetry: + + @patch('tradingagents.dataflows.brave.requests.get') + def test_successful_request(self, mock_get): + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status = Mock() + mock_get.return_value = mock_response + + result = _make_request_with_retry("http://test.com", {}, {}) + + assert result == mock_response + mock_get.assert_called_once() + + @patch('tradingagents.dataflows.brave.requests.get') + @patch('tradingagents.dataflows.brave.time.sleep') + def test_retry_on_timeout(self, mock_sleep, mock_get): + mock_get.side_effect = [ + 
requests.exceptions.Timeout(), + requests.exceptions.Timeout(), + Mock(status_code=200, raise_for_status=Mock()), + ] + + result = _make_request_with_retry("http://test.com", {}, {}) + + assert mock_get.call_count == 3 + assert mock_sleep.call_count == 2 + + @patch('tradingagents.dataflows.brave.requests.get') + @patch('tradingagents.dataflows.brave.time.sleep') + def test_retry_on_connection_error(self, mock_sleep, mock_get): + mock_get.side_effect = [ + requests.exceptions.ConnectionError(), + Mock(status_code=200, raise_for_status=Mock()), + ] + + result = _make_request_with_retry("http://test.com", {}, {}) + + assert mock_get.call_count == 2 + assert mock_sleep.call_count == 1 + + @patch('tradingagents.dataflows.brave.requests.get') + @patch('tradingagents.dataflows.brave.time.sleep') + def test_retry_on_rate_limit(self, mock_sleep, mock_get): + rate_limited_response = Mock() + rate_limited_response.status_code = 429 + rate_limited_response.headers = {"Retry-After": "1"} + + success_response = Mock() + success_response.status_code = 200 + success_response.raise_for_status = Mock() + + mock_get.side_effect = [rate_limited_response, success_response] + + result = _make_request_with_retry("http://test.com", {}, {}) + + assert mock_get.call_count == 2 + assert mock_sleep.call_count == 1 + + @patch('tradingagents.dataflows.brave.requests.get') + @patch('tradingagents.dataflows.brave.time.sleep') + def test_max_retries_exceeded(self, mock_sleep, mock_get): + mock_get.side_effect = requests.exceptions.Timeout() + + with pytest.raises(requests.exceptions.Timeout): + _make_request_with_retry("http://test.com", {}, {}, max_retries=3) + + assert mock_get.call_count == 3 + + @patch('tradingagents.dataflows.brave.requests.get') + def test_non_retryable_http_error(self, mock_get): + mock_response = Mock() + mock_response.status_code = 400 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=mock_response) + mock_get.return_value = mock_response + + with pytest.raises(requests.exceptions.HTTPError): + _make_request_with_retry("http://test.com", {}, {}) + + assert mock_get.call_count == 1 + + +class TestGetBulkNewsBrave: + + @patch('tradingagents.dataflows.brave.get_api_key') + def test_returns_empty_when_no_api_key(self, mock_get_api_key): + mock_get_api_key.side_effect = ValueError("BRAVE_API_KEY not set") + + result = get_bulk_news_brave(24) + + assert result == [] + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_basic_call(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + mock_response = Mock() + mock_response.json.return_value = {"results": []} + mock_request.return_value = mock_response + + result = get_bulk_news_brave(24) + + assert isinstance(result, list) + assert mock_request.call_count == 5 + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_parses_articles(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + + mock_article = { + "title": "Test Stock News", + "meta_url": {"netloc": "reuters.com"}, + "url": "https://reuters.com/article1", + "age": "2 hours ago", + "description": "This is a test article about stocks.", + } + + mock_response = Mock() + mock_response.json.return_value = {"results": [mock_article]} + mock_request.return_value = mock_response + + result = get_bulk_news_brave(24) + + assert len(result) >= 1 + article 
= result[0] + assert article["title"] == "Test Stock News" + assert article["source"] == "reuters.com" + assert article["url"] == "https://reuters.com/article1" + assert "published_at" in article + assert article["content_snippet"] == "This is a test article about stocks." + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_deduplicates_by_url(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + + duplicate_article = { + "title": "Duplicate Article", + "meta_url": {"netloc": "news.com"}, + "url": "https://news.com/same-url", + "age": "1 hour ago", + "description": "Duplicate content.", + } + + mock_response = Mock() + mock_response.json.return_value = {"results": [duplicate_article, duplicate_article]} + mock_request.return_value = mock_response + + result = get_bulk_news_brave(24) + + urls = [a["url"] for a in result] + assert len(urls) == len(set(urls)) + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_truncates_long_descriptions(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + + long_description = "A" * 1000 + + mock_article = { + "title": "Long Article", + "meta_url": {"netloc": "news.com"}, + "url": "https://news.com/article", + "age": "1 hour ago", + "description": long_description, + } + + mock_response = Mock() + mock_response.json.return_value = {"results": [mock_article]} + mock_request.return_value = mock_response + + result = get_bulk_news_brave(24) + + assert len(result[0]["content_snippet"]) == 500 + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_freshness_parameter_24h(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + mock_response = Mock() + mock_response.json.return_value = {"results": []} + mock_request.return_value = mock_response + + get_bulk_news_brave(24) + + call_args = mock_request.call_args_list[0] + params = call_args[0][2] + assert params["freshness"] == "pd" + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_freshness_parameter_7d(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + mock_response = Mock() + mock_response.json.return_value = {"results": []} + mock_request.return_value = mock_response + + get_bulk_news_brave(168) + + call_args = mock_request.call_args_list[0] + params = call_args[0][2] + assert params["freshness"] == "pw" + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_freshness_parameter_month(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + mock_response = Mock() + mock_response.json.return_value = {"results": []} + mock_request.return_value = mock_response + + get_bulk_news_brave(720) + + call_args = mock_request.call_args_list[0] + params = call_args[0][2] + assert params["freshness"] == "pm" + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_handles_missing_meta_url(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + + mock_article = { + "title": "Article Without Meta URL", + "url": "https://news.com/article", + "age": "1 hour ago", + 
"description": "Content", + } + + mock_response = Mock() + mock_response.json.return_value = {"results": [mock_article]} + mock_request.return_value = mock_response + + result = get_bulk_news_brave(24) + + assert result[0]["source"] == "Brave News" + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_continues_on_query_failure(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + + mock_response = Mock() + mock_response.json.return_value = {"results": [{"title": "Article", "url": "https://test.com", "age": "1h", "description": "test"}]} + + mock_request.side_effect = [ + requests.exceptions.HTTPError("Error"), + mock_response, + mock_response, + mock_response, + mock_response, + ] + + result = get_bulk_news_brave(24) + + assert len(result) > 0 + + @patch('tradingagents.dataflows.brave._make_request_with_retry') + @patch('tradingagents.dataflows.brave.get_api_key') + def test_skips_articles_without_url(self, mock_get_api_key, mock_request): + mock_get_api_key.return_value = "test_key" + + mock_articles = [ + {"title": "No URL Article", "age": "1h", "description": "test"}, + {"title": "Has URL", "url": "https://test.com", "age": "1h", "description": "test"}, + ] + + mock_response = Mock() + mock_response.json.return_value = {"results": mock_articles} + mock_request.return_value = mock_response + + result = get_bulk_news_brave(24) + + urls = [a["url"] for a in result if a.get("url")] + assert all(url for url in urls) diff --git a/tests/dataflows/test_tavily.py b/tests/dataflows/test_tavily.py new file mode 100644 index 00000000..cc1b793c --- /dev/null +++ b/tests/dataflows/test_tavily.py @@ -0,0 +1,370 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime, timedelta +from tradingagents.dataflows.tavily import ( + get_api_key, + get_bulk_news_tavily, + _search_with_retry, + DEFAULT_TIMEOUT, + MAX_RETRIES, +) + + +class TestGetApiKey: + + def test_get_api_key_success(self): + with patch.dict('os.environ', {'TAVILY_API_KEY': 'test_key_123'}): + result = get_api_key() + assert result == 'test_key_123' + + def test_get_api_key_missing(self): + with patch.dict('os.environ', {}, clear=True): + with pytest.raises(ValueError, match="TAVILY_API_KEY environment variable is not set"): + get_api_key() + + +class TestSearchWithRetry: + + def test_successful_search(self): + mock_client = Mock() + mock_client.search.return_value = {"results": []} + + result = _search_with_retry( + client=mock_client, + query="test query", + search_depth="advanced", + topic="news", + time_range="day", + max_results=10, + ) + + assert result == {"results": []} + mock_client.search.assert_called_once() + + @patch('tradingagents.dataflows.tavily.time.sleep') + def test_retry_on_rate_limit(self, mock_sleep): + mock_client = Mock() + mock_client.search.side_effect = [ + Exception("Rate limit exceeded"), + {"results": []}, + ] + + result = _search_with_retry( + client=mock_client, + query="test query", + search_depth="advanced", + topic="news", + time_range="day", + max_results=10, + ) + + assert result == {"results": []} + assert mock_client.search.call_count == 2 + assert mock_sleep.call_count == 1 + + @patch('tradingagents.dataflows.tavily.time.sleep') + def test_retry_on_timeout(self, mock_sleep): + mock_client = Mock() + mock_client.search.side_effect = [ + Exception("Request timed out"), + {"results": []}, + ] + + result = _search_with_retry( + client=mock_client, + 
query="test query", + search_depth="advanced", + topic="news", + time_range="day", + max_results=10, + ) + + assert result == {"results": []} + assert mock_client.search.call_count == 2 + + @patch('tradingagents.dataflows.tavily.time.sleep') + def test_retry_on_connection_error(self, mock_sleep): + mock_client = Mock() + mock_client.search.side_effect = [ + Exception("Connection error occurred"), + {"results": []}, + ] + + result = _search_with_retry( + client=mock_client, + query="test query", + search_depth="advanced", + topic="news", + time_range="day", + max_results=10, + ) + + assert result == {"results": []} + assert mock_client.search.call_count == 2 + + @patch('tradingagents.dataflows.tavily.time.sleep') + def test_max_retries_exceeded(self, mock_sleep): + mock_client = Mock() + mock_client.search.side_effect = Exception("Rate limit 429") + + with pytest.raises(Exception, match="Rate limit 429"): + _search_with_retry( + client=mock_client, + query="test query", + search_depth="advanced", + topic="news", + time_range="day", + max_results=10, + max_retries=3, + ) + + assert mock_client.search.call_count == 3 + + def test_non_retryable_error(self): + mock_client = Mock() + mock_client.search.side_effect = Exception("Invalid API key") + + with pytest.raises(Exception, match="Invalid API key"): + _search_with_retry( + client=mock_client, + query="test query", + search_depth="advanced", + topic="news", + time_range="day", + max_results=10, + ) + + assert mock_client.search.call_count == 1 + + +class TestGetBulkNewsTavily: + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', False) + def test_returns_empty_when_library_not_installed(self): + result = get_bulk_news_tavily(24) + assert result == [] + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + def test_returns_empty_when_no_api_key(self, mock_get_api_key, mock_client_class): + mock_get_api_key.side_effect = ValueError("TAVILY_API_KEY not set") + + result = get_bulk_news_tavily(24) + + assert result == [] + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_basic_call(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + mock_search.return_value = {"results": []} + + result = get_bulk_news_tavily(24) + + assert isinstance(result, list) + assert mock_search.call_count == 5 + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_parses_articles(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + + mock_article = { + "title": "Test Stock News", + "url": "https://reuters.com/article1", + "published_date": "2024-01-15T10:30:00Z", + "content": "This is a test article about stocks.", + } + + mock_search.return_value = {"results": [mock_article]} + + result = get_bulk_news_tavily(24) + + assert len(result) >= 1 + article = result[0] + assert article["title"] == "Test Stock News" + assert article["source"] == "Tavily" + assert article["url"] == "https://reuters.com/article1" + assert "published_at" in article + assert 
article["content_snippet"] == "This is a test article about stocks." + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_deduplicates_by_url(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + + duplicate_article = { + "title": "Duplicate Article", + "url": "https://news.com/same-url", + "published_date": "2024-01-15T10:30:00Z", + "content": "Duplicate content.", + } + + mock_search.return_value = {"results": [duplicate_article, duplicate_article]} + + result = get_bulk_news_tavily(24) + + urls = [a["url"] for a in result] + assert len(urls) == len(set(urls)) + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_truncates_long_content(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + + long_content = "A" * 1000 + + mock_article = { + "title": "Long Article", + "url": "https://news.com/article", + "published_date": "2024-01-15T10:30:00Z", + "content": long_content, + } + + mock_search.return_value = {"results": [mock_article]} + + result = get_bulk_news_tavily(24) + + assert len(result[0]["content_snippet"]) == 500 + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_time_range_day(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + mock_search.return_value = {"results": []} + + get_bulk_news_tavily(24) + + call_kwargs = mock_search.call_args_list[0][1] + assert call_kwargs["time_range"] == "day" + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_time_range_week(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + mock_search.return_value = {"results": []} + + get_bulk_news_tavily(168) + + call_kwargs = mock_search.call_args_list[0][1] + assert call_kwargs["time_range"] == "week" + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_time_range_month(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + mock_search.return_value = {"results": []} + + get_bulk_news_tavily(720) + + call_kwargs = mock_search.call_args_list[0][1] + assert call_kwargs["time_range"] == "month" + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_handles_missing_published_date(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + + 
mock_article = { + "title": "Article Without Date", + "url": "https://news.com/article", + "content": "Content", + } + + mock_search.return_value = {"results": [mock_article]} + + result = get_bulk_news_tavily(24) + + assert len(result) == 1 + assert "published_at" in result[0] + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_handles_invalid_date_format(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + + mock_article = { + "title": "Article With Bad Date", + "url": "https://news.com/article", + "published_date": "invalid_date_format", + "content": "Content", + } + + mock_search.return_value = {"results": [mock_article]} + + result = get_bulk_news_tavily(24) + + assert len(result) == 1 + assert "published_at" in result[0] + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_continues_on_query_failure(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + + mock_search.side_effect = [ + Exception("Query failed"), + {"results": [{"title": "Article", "url": "https://test.com", "content": "test"}]}, + {"results": []}, + {"results": []}, + {"results": []}, + ] + + result = get_bulk_news_tavily(24) + + assert len(result) > 0 + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_skips_articles_without_url(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + + mock_articles = [ + {"title": "No URL Article", "content": "test"}, + {"title": "Has URL", "url": "https://test.com", "content": "test"}, + ] + + mock_search.return_value = {"results": mock_articles} + + result = get_bulk_news_tavily(24) + + urls = [a["url"] for a in result if a.get("url")] + assert all(url for url in urls) + + @patch('tradingagents.dataflows.tavily.TAVILY_AVAILABLE', True) + @patch('tradingagents.dataflows.tavily.TavilyClient') + @patch('tradingagents.dataflows.tavily.get_api_key') + @patch('tradingagents.dataflows.tavily._search_with_retry') + def test_uses_correct_search_parameters(self, mock_search, mock_get_api_key, mock_client_class): + mock_get_api_key.return_value = "test_key" + mock_search.return_value = {"results": []} + + get_bulk_news_tavily(24) + + call_kwargs = mock_search.call_args_list[0][1] + assert call_kwargs["search_depth"] == "advanced" + assert call_kwargs["topic"] == "news" + assert call_kwargs["max_results"] == 10 diff --git a/tradingagents/dataflows/brave.py b/tradingagents/dataflows/brave.py new file mode 100644 index 00000000..b355f9a9 --- /dev/null +++ b/tradingagents/dataflows/brave.py @@ -0,0 +1,150 @@ +import os +import time +import requests +from datetime import datetime, timedelta +from typing import List, Dict, Any + +BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/news/search" +DEFAULT_TIMEOUT = 30 +MAX_RETRIES = 3 +RETRY_BACKOFF = 1.0 + + +def get_api_key() -> str: + api_key = os.getenv("BRAVE_API_KEY") + if not api_key: + raise 
ValueError("BRAVE_API_KEY environment variable is not set.") + return api_key + + +def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries: int = MAX_RETRIES) -> requests.Response: + last_exception = None + for attempt in range(max_retries): + try: + response = requests.get(url, headers=headers, params=params, timeout=DEFAULT_TIMEOUT) + if response.status_code == 429: + retry_after = int(response.headers.get("Retry-After", RETRY_BACKOFF * (attempt + 1))) + print(f"DEBUG: Brave rate limited, waiting {retry_after}s before retry {attempt + 1}/{max_retries}") + time.sleep(retry_after) + continue + response.raise_for_status() + return response + except requests.exceptions.Timeout as e: + last_exception = e + print(f"DEBUG: Brave request timeout, retry {attempt + 1}/{max_retries}") + time.sleep(RETRY_BACKOFF * (attempt + 1)) + except requests.exceptions.ConnectionError as e: + last_exception = e + print(f"DEBUG: Brave connection error, retry {attempt + 1}/{max_retries}") + time.sleep(RETRY_BACKOFF * (attempt + 1)) + except requests.exceptions.HTTPError as e: + if e.response is not None and e.response.status_code >= 500: + last_exception = e + print(f"DEBUG: Brave server error {e.response.status_code}, retry {attempt + 1}/{max_retries}") + time.sleep(RETRY_BACKOFF * (attempt + 1)) + else: + raise + raise last_exception if last_exception else requests.exceptions.RequestException("Max retries exceeded") + + +def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]: + try: + api_key = get_api_key() + except ValueError as e: + print(f"DEBUG: Brave API key not configured: {e}") + return [] + + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": api_key, + } + + queries = [ + "stock market news", + "earnings report", + "merger acquisition", + "company financial news", + "trading stocks", + ] + + all_articles = [] + seen_urls = set() + + if lookback_hours <= 24: + freshness = "pd" + elif lookback_hours <= 168: + freshness = "pw" + else: + freshness = "pm" + + for query in queries: + try: + params = { + "q": query, + "count": 20, + "freshness": freshness, + } + + response = _make_request_with_retry(BRAVE_SEARCH_URL, headers, params) + + data = response.json() + results = data.get("results", []) + + for item in results: + url = item.get("url", "") + if url and url not in seen_urls: + seen_urls.add(url) + + age = item.get("age", "") + published_at = _parse_brave_age(age) + + article = { + "title": item.get("title", ""), + "source": item.get("meta_url", {}).get("netloc", "Brave News"), + "url": url, + "published_at": published_at.isoformat(), + "content_snippet": item.get("description", "")[:500], + } + all_articles.append(article) + + except requests.exceptions.HTTPError as e: + print(f"DEBUG: Brave search HTTP error for '{query}': {e}") + continue + except requests.exceptions.Timeout as e: + print(f"DEBUG: Brave search timeout for '{query}': {e}") + continue + except requests.exceptions.RequestException as e: + print(f"DEBUG: Brave search request failed for '{query}': {e}") + continue + except Exception as e: + print(f"DEBUG: Brave search failed for query '{query}': {e}") + continue + + print(f"DEBUG: Brave returned {len(all_articles)} articles") + return all_articles + + +def _parse_brave_age(age_str: str) -> datetime: + now = datetime.now() + if not age_str: + return now + + age_str = age_str.lower() + try: + if "hour" in age_str: + hours = int("".join(filter(str.isdigit, age_str)) or "1") + return now - 
timedelta(hours=hours) + elif "day" in age_str: + days = int("".join(filter(str.isdigit, age_str)) or "1") + return now - timedelta(days=days) + elif "week" in age_str: + weeks = int("".join(filter(str.isdigit, age_str)) or "1") + return now - timedelta(weeks=weeks) + elif "minute" in age_str: + minutes = int("".join(filter(str.isdigit, age_str)) or "1") + return now - timedelta(minutes=minutes) + except (ValueError, TypeError): + pass + + return now diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 6a91d5e4..eda7a7e4 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -18,6 +18,8 @@ from .alpha_vantage import ( ) from .alpha_vantage_news import get_bulk_news_alpha_vantage from .alpha_vantage_common import AlphaVantageRateLimitError +from .tavily import get_bulk_news_tavily +from .brave import get_bulk_news_brave from .config import get_config @@ -113,6 +115,8 @@ VENDOR_METHODS = { "local": get_finnhub_company_insider_transactions, }, "get_bulk_news": { + "tavily": get_bulk_news_tavily, + "brave": get_bulk_news_brave, "alpha_vantage": get_bulk_news_alpha_vantage, "openai": get_bulk_news_openai, "google": get_bulk_news_google, @@ -192,7 +196,8 @@ def _convert_to_news_articles(raw_articles: List[Dict[str, Any]]) -> List[NewsAr def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]: lookback_hours = parse_lookback_period(lookback_period) - vendor_order = ["alpha_vantage", "openai", "google"] + config = get_config() + vendor_order = config.get("bulk_news_vendor_order", ["tavily", "brave", "alpha_vantage", "openai", "google"]) for vendor in vendor_order: if vendor not in VENDOR_METHODS["get_bulk_news"]: diff --git a/tradingagents/dataflows/tavily.py b/tradingagents/dataflows/tavily.py new file mode 100644 index 00000000..560202a0 --- /dev/null +++ b/tradingagents/dataflows/tavily.py @@ -0,0 +1,128 @@ +import os +import time +from datetime import datetime, timedelta +from typing import List, Dict, Any + +try: + from tavily import TavilyClient + TAVILY_AVAILABLE = True +except ImportError: + TAVILY_AVAILABLE = False + +DEFAULT_TIMEOUT = 30 +MAX_RETRIES = 3 +RETRY_BACKOFF = 1.0 + + +def get_api_key() -> str: + api_key = os.getenv("TAVILY_API_KEY") + if not api_key: + raise ValueError("TAVILY_API_KEY environment variable is not set.") + return api_key + + +def _search_with_retry(client, query: str, search_depth: str, topic: str, time_range: str, max_results: int, max_retries: int = MAX_RETRIES) -> Dict[str, Any]: + last_exception = None + for attempt in range(max_retries): + try: + response = client.search( + query=query, + search_depth=search_depth, + topic=topic, + time_range=time_range, + max_results=max_results, + ) + return response + except Exception as e: + error_str = str(e).lower() + if "rate" in error_str or "limit" in error_str or "429" in error_str: + wait_time = RETRY_BACKOFF * (attempt + 1) * 2 + print(f"DEBUG: Tavily rate limited, waiting {wait_time}s before retry {attempt + 1}/{max_retries}") + time.sleep(wait_time) + last_exception = e + elif "timeout" in error_str or "timed out" in error_str: + wait_time = RETRY_BACKOFF * (attempt + 1) + print(f"DEBUG: Tavily timeout, waiting {wait_time}s before retry {attempt + 1}/{max_retries}") + time.sleep(wait_time) + last_exception = e + elif "connection" in error_str or "network" in error_str: + wait_time = RETRY_BACKOFF * (attempt + 1) + print(f"DEBUG: Tavily connection error, waiting {wait_time}s before retry {attempt + 
1}/{max_retries}") + time.sleep(wait_time) + last_exception = e + else: + raise + raise last_exception if last_exception else Exception("Max retries exceeded") + + +def get_bulk_news_tavily(lookback_hours: int) -> List[Dict[str, Any]]: + if not TAVILY_AVAILABLE: + print("DEBUG: Tavily library not installed") + return [] + + try: + client = TavilyClient(api_key=get_api_key()) + except ValueError as e: + print(f"DEBUG: Tavily API key not configured: {e}") + return [] + + queries = [ + "stock market news today", + "earnings report announcement", + "merger acquisition deal", + "IPO stock market", + "company financial results", + ] + + days = max(1, lookback_hours // 24) + if lookback_hours <= 24: + time_range = "day" + elif lookback_hours <= 168: + time_range = "week" + else: + time_range = "month" + + all_articles = [] + seen_urls = set() + + for query in queries: + try: + response = _search_with_retry( + client=client, + query=query, + search_depth="advanced", + topic="news", + time_range=time_range, + max_results=10, + ) + + results = response.get("results", []) + for item in results: + url = item.get("url", "") + if url and url not in seen_urls: + seen_urls.add(url) + + published_date = item.get("published_date") + if published_date: + try: + published_at = datetime.fromisoformat(published_date.replace("Z", "+00:00")) + except (ValueError, TypeError): + published_at = datetime.now() + else: + published_at = datetime.now() + + article = { + "title": item.get("title", ""), + "source": "Tavily", + "url": url, + "published_at": published_at.isoformat(), + "content_snippet": item.get("content", "")[:500], + } + all_articles.append(article) + + except Exception as e: + print(f"DEBUG: Tavily search failed for query '{query}': {e}") + continue + + print(f"DEBUG: Tavily returned {len(all_articles)} articles") + return all_articles diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index e88868d6..f0d37cb0 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -28,4 +28,7 @@ DEFAULT_CONFIG = { "discovery_cache_ttl": 300, "discovery_max_results": 20, "discovery_min_mentions": 2, + "bulk_news_vendor_order": ["tavily", "brave", "alpha_vantage", "openai", "google"], + "bulk_news_timeout": 30, + "bulk_news_max_retries": 3, }
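Reviewer note (not part of the patch): the sketch below is a quick way to exercise the new vendors outside the full agent pipeline. It uses only functions this patch adds (`get_bulk_news_tavily`, `get_bulk_news_brave`) and assumes the corresponding API keys are exported; the manual loop loosely mirrors the fallback behaviour that `_fetch_bulk_news_from_vendor` drives from `bulk_news_vendor_order`.

```python
# Manual smoke test for the new bulk-news vendors (sketch only, not shipped in this patch).
# Assumes TAVILY_API_KEY and/or BRAVE_API_KEY are set; an unconfigured vendor returns []
# and the loop falls through to the next one, similar to the fallback chain in interface.py.
from tradingagents.dataflows.tavily import get_bulk_news_tavily
from tradingagents.dataflows.brave import get_bulk_news_brave

LOOKBACK_HOURS = 24  # maps to time_range="day" (Tavily) / freshness="pd" (Brave)

for name, fetch in [("tavily", get_bulk_news_tavily), ("brave", get_bulk_news_brave)]:
    articles = fetch(LOOKBACK_HOURS)
    if articles:
        first = articles[0]
        print(f"{name}: {len(articles)} articles; first: {first['title']} ({first['url']})")
        break
else:
    print("No configured vendor returned any articles")
```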