# 163 lines, 5.0 KiB, Python
import logging
|
|
import os
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# The Tavily SDK is an optional dependency. Record whether the import
# succeeded so the rest of the module can return empty results instead of
# failing at import time when the package is missing.
try:
    from tavily import TavilyClient
except ImportError:
    TAVILY_AVAILABLE = False
else:
    TAVILY_AVAILABLE = True

# NOTE(review): DEFAULT_TIMEOUT is not referenced by the code visible in this
# file; presumably a timeout in seconds -- confirm before relying on it.
DEFAULT_TIMEOUT = 30
# Default number of search attempts made by _search_with_retry.
MAX_RETRIES = 3
# Base multiplier for the sleep between retries; scaled by the attempt number.
RETRY_BACKOFF = 1.0
|
|
|
|
|
|
def get_api_key() -> str:
    """Return the Tavily API key.

    The key is requested from the project settings object via
    ``get_settings().require_api_key("tavily")``. If that raises
    ``ImportError`` (for example because ``tradingagents.config`` is not
    importable), the ``TAVILY_API_KEY`` environment variable is used instead.

    Returns:
        The API key string.

    Raises:
        ValueError: If the environment fallback is used and
            ``TAVILY_API_KEY`` is unset or empty.
    """
    try:
        from tradingagents.config import get_settings

        settings = get_settings()
        return settings.require_api_key("tavily")
    except ImportError:
        # Settings module unavailable: fall back to the raw environment.
        env_key = os.environ.get("TAVILY_API_KEY")
        if env_key:
            return env_key
        raise ValueError("TAVILY_API_KEY environment variable is not set.")
|
|
|
|
|
|
def _search_with_retry(
    client,
    query: str,
    search_depth: str,
    topic: str,
    time_range: str,
    max_results: int,
    max_retries: int = MAX_RETRIES,
) -> dict[str, Any]:
    """Call ``client.search`` and retry on transient failures.

    An error is treated as transient when it is a rate-limit error (message
    contains "rate", "limit" or "429"), a timeout (a ``TimeoutError`` or a
    message containing "timeout" / "timed out"), or a connection problem (a
    ``ConnectionError`` or a message containing "connection" / "network").
    Between attempts the function sleeps ``RETRY_BACKOFF * attempt_number``
    seconds, doubled for rate-limit errors.

    Args:
        client: Object exposing a ``search(...)`` method accepting the
            keyword arguments below.
        query: Search query string.
        search_depth: Passed through to ``client.search``.
        topic: Passed through to ``client.search``.
        time_range: Passed through to ``client.search``.
        max_results: Passed through to ``client.search``.
        max_retries: Total number of attempts to make.

    Returns:
        Whatever ``client.search`` returns on the first successful attempt.

    Raises:
        RuntimeError, ConnectionError, TimeoutError, OSError: Re-raised
            immediately when the error is not recognised as transient, or
            raised after the final attempt when it is.
        RuntimeError: If ``max_retries`` is less than 1, so no attempt was
            made.
    """
    last_exception = None
    for attempt in range(max_retries):
        try:
            return client.search(
                query=query,
                search_depth=search_depth,
                topic=topic,
                time_range=time_range,
                max_results=max_results,
            )
        except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
            error_str = str(e).lower()
            # Classification order matters: rate limiting is checked first
            # because it uses the longer backoff.
            if "rate" in error_str or "limit" in error_str or "429" in error_str:
                reason, multiplier = "rate limited", 2
            elif (
                isinstance(e, TimeoutError)
                or "timeout" in error_str
                or "timed out" in error_str
            ):
                reason, multiplier = "timeout", 1
            elif (
                isinstance(e, ConnectionError)
                or "connection" in error_str
                or "network" in error_str
            ):
                reason, multiplier = "connection error", 1
            else:
                # Not a transient failure: retrying would not help.
                raise

            last_exception = e
            if attempt + 1 >= max_retries:
                # No attempts left; sleeping here would only delay the error.
                break

            wait_time = RETRY_BACKOFF * (attempt + 1) * multiplier
            logger.debug(
                "Tavily %s, waiting %.1fs before retry %d/%d",
                reason,
                wait_time,
                attempt + 1,
                max_retries,
            )
            time.sleep(wait_time)

    if last_exception is not None:
        raise last_exception
    # Only reachable when max_retries < 1. RuntimeError is used so that
    # callers catching the same exception family as above also catch this.
    raise RuntimeError("Max retries exceeded")
|
|
|
|
|
|
def _parse_published_at(raw: Any) -> datetime:
    """Parse an ISO-8601 timestamp string, falling back to the current time."""
    if isinstance(raw, str) and raw:
        try:
            # fromisoformat() on older Python versions rejects a trailing "Z".
            return datetime.fromisoformat(raw.replace("Z", "+00:00"))
        except ValueError:
            pass
    # NOTE(review): this fallback is naive local time, while timestamps parsed
    # from a "Z"/offset string are timezone-aware, so the returned
    # "published_at" values can mix both forms. Left unchanged to avoid
    # altering output that callers may depend on.
    return datetime.now()


def get_bulk_news_tavily(lookback_hours: int) -> list[dict[str, Any]]:
    """Fetch general market news from Tavily, de-duplicated by URL.

    Runs a fixed list of market-related queries with ``topic="news"`` and
    merges the results. The function is best-effort: it returns an empty list
    when the Tavily library is not installed or the API key is not
    configured, and skips any individual query whose search fails.

    Args:
        lookback_hours: Size of the news window in hours. Mapped to Tavily's
            ``time_range``: up to 24 -> "day", up to 168 -> "week",
            otherwise "month".

    Returns:
        A list of article dicts with the keys ``title``, ``source`` (always
        "Tavily"), ``url``, ``published_at`` (ISO-8601 string) and
        ``content_snippet`` (at most 500 characters).
    """
    if not TAVILY_AVAILABLE:
        logger.debug("Tavily library not installed")
        return []

    try:
        client = TavilyClient(api_key=get_api_key())
    except ValueError as e:
        logger.debug("Tavily API key not configured: %s", e)
        return []

    queries = [
        "stock market news today",
        "earnings report announcement",
        "merger acquisition deal",
        "IPO stock market",
        "company financial results",
    ]

    if lookback_hours <= 24:
        time_range = "day"
    elif lookback_hours <= 168:  # 7 days
        time_range = "week"
    else:
        time_range = "month"

    all_articles: list[dict[str, Any]] = []
    seen_urls: set[str] = set()

    for query in queries:
        try:
            response = _search_with_retry(
                client=client,
                query=query,
                search_depth="advanced",
                topic="news",
                time_range=time_range,
                max_results=10,
            )

            # "or []" also covers a response whose "results" value is None.
            for item in response.get("results") or []:
                url = item.get("url", "")
                if not url or url in seen_urls:
                    continue
                seen_urls.add(url)

                # NOTE(review): dates that are not ISO-8601 silently become
                # "now" -- confirm the format Tavily returns for news results.
                published_at = _parse_published_at(item.get("published_date"))

                all_articles.append(
                    {
                        # "or ''" guards against keys present with a None
                        # value, which would otherwise break the slice below.
                        "title": item.get("title") or "",
                        "source": "Tavily",
                        "url": url,
                        "published_at": published_at.isoformat(),
                        "content_snippet": (item.get("content") or "")[:500],
                    }
                )

        except (RuntimeError, ConnectionError, TimeoutError, OSError, ValueError) as e:
            # One failing query must not discard results from the others.
            logger.debug("Tavily search failed for query '%s': %s", query, e)
            continue

    logger.debug("Tavily returned %d articles", len(all_articles))
    return all_articles
|