471 lines
15 KiB
Python
471 lines
15 KiB
Python
import logging
|
|
import threading
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from tradingagents.agents.discovery import NewsArticle
|
|
|
|
from .alpha_vantage import get_balance_sheet as get_alpha_vantage_balance_sheet
|
|
from .alpha_vantage import get_cashflow as get_alpha_vantage_cashflow
|
|
from .alpha_vantage import get_fundamentals as get_alpha_vantage_fundamentals
|
|
from .alpha_vantage import get_income_statement as get_alpha_vantage_income_statement
|
|
from .alpha_vantage import get_indicator as get_alpha_vantage_indicator
|
|
from .alpha_vantage import (
|
|
get_insider_transactions as get_alpha_vantage_insider_transactions,
|
|
)
|
|
from .alpha_vantage import get_news as get_alpha_vantage_news
|
|
from .alpha_vantage import get_stock as get_alpha_vantage_stock
|
|
from .alpha_vantage_common import AlphaVantageRateLimitError
|
|
from .alpha_vantage_news import get_bulk_news_alpha_vantage
|
|
from .brave import get_bulk_news_brave
|
|
from .config import get_config
|
|
from .google import get_bulk_news_google, get_google_news
|
|
from .local import (
|
|
get_finnhub_company_insider_sentiment,
|
|
get_finnhub_company_insider_transactions,
|
|
get_finnhub_news,
|
|
get_reddit_company_news,
|
|
get_reddit_global_news,
|
|
get_simfin_balance_sheet,
|
|
get_simfin_cashflow,
|
|
get_simfin_income_statements,
|
|
get_YFin_data,
|
|
)
|
|
from .openai import (
|
|
get_bulk_news_openai,
|
|
get_fundamentals_openai,
|
|
get_global_news_openai,
|
|
get_stock_news_openai,
|
|
)
|
|
from .tavily import get_bulk_news_tavily
|
|
from .y_finance import get_balance_sheet as get_yfinance_balance_sheet
|
|
from .y_finance import get_cashflow as get_yfinance_cashflow
|
|
from .y_finance import get_income_statement as get_yfinance_income_statement
|
|
from .y_finance import get_insider_transactions as get_yfinance_insider_transactions
|
|
from .y_finance import get_stock_stats_indicators_window, get_YFin_data_online
|
|
|
|
# Module-level logger; callers configure handlers/levels at application startup.
logger = logging.getLogger(__name__)


# Registry of tool categories: maps a category key to a human-readable
# description plus the list of router method names belonging to it.
# get_category_for_method() searches the "tools" lists to map a method
# back to its category.
TOOLS_CATEGORIES = {
    "core_stock_apis": {
        "description": "OHLCV stock price data",
        "tools": ["get_stock_data"],
    },
    "technical_indicators": {
        "description": "Technical analysis indicators",
        "tools": ["get_indicators"],
    },
    "fundamental_data": {
        "description": "Company fundamentals",
        "tools": [
            "get_fundamentals",
            "get_balance_sheet",
            "get_cashflow",
            "get_income_statement",
        ],
    },
    "news_data": {
        "description": "News (public/insiders, original/processed)",
        "tools": [
            "get_news",
            "get_global_news",
            "get_insider_sentiment",
            "get_insider_transactions",
            "get_bulk_news",
        ],
    },
}

# NOTE(review): not referenced anywhere in this module — presumably consumed
# by other modules (e.g. config validation); confirm before removing.
VENDOR_LIST = ["local", "yfinance", "openai", "google"]

# Maps each router method name to its per-vendor implementation.
# A value may be a single callable or a list of callables; route_to_vendor()
# invokes every callable in a list and combines the results.
VENDOR_METHODS = {
    "get_stock_data": {
        "alpha_vantage": get_alpha_vantage_stock,
        "yfinance": get_YFin_data_online,
        "local": get_YFin_data,
    },
    "get_indicators": {
        "alpha_vantage": get_alpha_vantage_indicator,
        "yfinance": get_stock_stats_indicators_window,
        "local": get_stock_stats_indicators_window,
    },
    "get_fundamentals": {
        "alpha_vantage": get_alpha_vantage_fundamentals,
        "openai": get_fundamentals_openai,
    },
    "get_balance_sheet": {
        "alpha_vantage": get_alpha_vantage_balance_sheet,
        "yfinance": get_yfinance_balance_sheet,
        "local": get_simfin_balance_sheet,
    },
    "get_cashflow": {
        "alpha_vantage": get_alpha_vantage_cashflow,
        "yfinance": get_yfinance_cashflow,
        "local": get_simfin_cashflow,
    },
    "get_income_statement": {
        "alpha_vantage": get_alpha_vantage_income_statement,
        "yfinance": get_yfinance_income_statement,
        "local": get_simfin_income_statements,
    },
    "get_news": {
        "alpha_vantage": get_alpha_vantage_news,
        "openai": get_stock_news_openai,
        "google": get_google_news,
        # "local" aggregates three sources; all are called and merged.
        "local": [get_finnhub_news, get_reddit_company_news, get_google_news],
    },
    "get_global_news": {
        "openai": get_global_news_openai,
        "local": get_reddit_global_news,
    },
    "get_insider_sentiment": {"local": get_finnhub_company_insider_sentiment},
    "get_insider_transactions": {
        "alpha_vantage": get_alpha_vantage_insider_transactions,
        "yfinance": get_yfinance_insider_transactions,
        "local": get_finnhub_company_insider_transactions,
    },
    "get_bulk_news": {
        "tavily": get_bulk_news_tavily,
        "brave": get_bulk_news_brave,
        "alpha_vantage": get_bulk_news_alpha_vantage,
        "openai": get_bulk_news_openai,
        "google": get_bulk_news_google,
    },
}

# TTL for the in-process bulk-news cache, in seconds (5 minutes).
CACHE_TTL_SECONDS = 300

# In-process cache of bulk news keyed by lookback period ("1h", "24h", ...).
# Each entry is {"timestamp": datetime, "articles": list[NewsArticle]}.
# All access must hold _bulk_news_cache_lock.
_bulk_news_cache: dict[str, dict[str, Any]] = {}
_bulk_news_cache_lock = threading.Lock()
|
|
|
|
|
|
def _is_db_cache_enabled() -> bool:
    """Return True when both the database and its cache layer are enabled in config.

    "db_cache_enabled" defaults to True so the cache is on whenever the
    database itself is enabled, unless explicitly switched off.
    """
    config = get_config()
    database_on = config.get("database_enabled", False)
    cache_on = config.get("db_cache_enabled", True)
    return database_on and cache_on
|
|
|
|
|
|
def _get_db_cached_data(method: str, ticker: str | None = None, **kwargs: Any) -> Any:
    """Look up previously cached vendor data in the database.

    Returns the cached payload, or None on a cache miss, when DB caching is
    disabled, or when the lookup fails for any tolerated reason (import,
    runtime, or connection problems are logged at debug level and swallowed).
    """
    if not _is_db_cache_enabled():
        return None

    try:
        # Imported lazily so the module works without the database extra.
        from tradingagents.database import MarketDataService, get_db_session

        with get_db_session() as session:
            return MarketDataService(session).get_cached_data(method, ticker, **kwargs)
    except (ImportError, RuntimeError, ConnectionError) as exc:
        logger.debug("DB cache lookup failed: %s", exc)
        return None
|
|
|
|
|
|
def _set_db_cached_data(
    method: str,
    data: Any,
    ticker: str | None = None,
    **kwargs: Any,
) -> None:
    """Persist vendor data in the database cache (best-effort, never raises).

    Uses the per-method default TTL from the database package. Failures
    (import, runtime, or connection) are logged at debug level and swallowed
    so caching never breaks the data path.
    """
    if not _is_db_cache_enabled():
        return

    try:
        # Imported lazily so the module works without the database extra.
        from tradingagents.database import (
            MarketDataService,
            get_db_session,
            get_default_ttl,
        )

        ttl = get_default_ttl(method)
        with get_db_session() as session:
            MarketDataService(session).set_cached_data(
                method, data, ticker, ttl_hours=ttl, **kwargs
            )
    except (ImportError, RuntimeError, ConnectionError) as exc:
        logger.debug("DB cache write failed: %s", exc)
|
|
|
|
|
|
def parse_lookback_period(lookback: str) -> int:
    """Convert a lookback string such as '1h', '24h', or '7d' into hours.

    Generalizes the original fixed set {1h, 6h, 24h, 7d}: any positive
    integer followed by 'h' (hours) or 'd' (days) is accepted, and the
    original four values map to the same results as before (1, 6, 24, 168).
    Input is case-insensitive and surrounding whitespace is ignored.

    Args:
        lookback: Period string in the form '<N>h' or '<N>d'.

    Returns:
        The period length in whole hours.

    Raises:
        ValueError: If the string is not a positive integer followed by
            'h' or 'd'.
    """
    lookback = lookback.lower().strip()

    count, unit = lookback[:-1], lookback[-1:]
    # str.isdigit() also rejects the empty string, signs, and decimals.
    if unit in ("h", "d") and count.isdigit() and int(count) > 0:
        hours = int(count)
        return hours * 24 if unit == "d" else hours

    raise ValueError(
        f"Invalid lookback period: {lookback}. "
        "Expected '<N>h' or '<N>d' (e.g. 1h, 6h, 24h, 7d)"
    )
|
|
|
|
|
|
def _get_cached_bulk_news(lookback_period: str) -> list[NewsArticle] | None:
    """Return cached bulk-news articles for this period, or None if stale/missing.

    An entry counts as fresh when its timestamp is within CACHE_TTL_SECONDS
    of now. Access is serialized via the module-level cache lock.
    """
    with _bulk_news_cache_lock:
        entry = _bulk_news_cache.get(lookback_period)
        if entry is None:
            return None
        stamped_at = entry.get("timestamp")
        if stamped_at is None:
            return None
        age = (datetime.now() - stamped_at).total_seconds()
        if age < CACHE_TTL_SECONDS:
            return entry.get("articles")
    return None
|
|
|
|
|
|
def _set_cached_bulk_news(lookback_period: str, articles: list[NewsArticle]) -> None:
    """Store articles in the in-memory bulk-news cache with a fresh timestamp."""
    entry = {
        "timestamp": datetime.now(),
        "articles": articles,
    }
    with _bulk_news_cache_lock:
        _bulk_news_cache[lookback_period] = entry
|
|
|
|
|
|
def _convert_to_news_articles(raw_articles: list[dict[str, Any]]) -> list[NewsArticle]:
    """Map raw vendor article dicts to NewsArticle objects, skipping bad entries.

    "published_at" may be an ISO-8601 string (trailing 'Z' accepted) or a
    datetime; anything else — including unparseable strings — falls back to
    the current time. Entries that fail conversion are logged and dropped.
    """
    converted = []
    for raw in raw_articles:
        try:
            raw_ts = raw.get("published_at", "")
            if isinstance(raw_ts, datetime):
                published_at = raw_ts
            elif isinstance(raw_ts, str):
                try:
                    # Normalize the 'Z' suffix, which fromisoformat rejects
                    # on older Python versions.
                    published_at = datetime.fromisoformat(
                        raw_ts.replace("Z", "+00:00")
                    )
                except ValueError:
                    published_at = datetime.now()
            else:
                published_at = datetime.now()

            converted.append(
                NewsArticle(
                    title=raw.get("title", ""),
                    source=raw.get("source", ""),
                    url=raw.get("url", ""),
                    published_at=published_at,
                    content_snippet=raw.get("content_snippet", ""),
                    ticker_mentions=[],
                )
            )
        except (KeyError, TypeError, ValueError) as exc:
            logger.debug("Error converting article to NewsArticle: %s", exc)
    return converted
|
|
|
|
|
|
def _fetch_bulk_news_from_vendor(lookback_period: str) -> list[dict[str, Any]]:
    """Try configured bulk-news vendors in order; return the first non-empty result.

    Vendor order comes from config key "bulk_news_vendor_order" (with a
    built-in default). Unknown vendors are skipped; rate limits and common
    runtime failures are logged and the next vendor is tried. Returns an
    empty list when every vendor fails or yields nothing.
    """
    hours = parse_lookback_period(lookback_period)

    vendor_order = get_config().get(
        "bulk_news_vendor_order",
        ["tavily", "brave", "alpha_vantage", "openai", "google"],
    )
    bulk_impls = VENDOR_METHODS["get_bulk_news"]

    for vendor in vendor_order:
        fetch = bulk_impls.get(vendor)
        if fetch is None:
            continue

        try:
            logger.debug("Attempting bulk news from vendor '%s'...", vendor)
            articles = fetch(hours)
            if articles:
                logger.info("Got %d articles from vendor '%s'", len(articles), vendor)
                return articles
            logger.debug("Vendor '%s' returned empty results, trying next...", vendor)
        except AlphaVantageRateLimitError as exc:
            logger.warning("Alpha Vantage rate limit exceeded: %s", exc)
            continue
        except (RuntimeError, ConnectionError, TimeoutError, ValueError, OSError) as exc:
            logger.error("Vendor '%s' failed: %s", vendor, exc)
            continue

    return []
|
|
|
|
|
|
def get_bulk_news(lookback_period: str = "24h") -> list[NewsArticle]:
    """Return recent market-wide news, serving from the in-memory cache when fresh.

    On a cache miss, fetches raw articles from the configured vendor chain,
    converts them to NewsArticle objects, caches the converted list (even if
    empty), and returns it.
    """
    hit = _get_cached_bulk_news(lookback_period)
    if hit is not None:
        logger.debug("Returning cached bulk news for period '%s'", lookback_period)
        return hit

    raw = _fetch_bulk_news_from_vendor(lookback_period)
    articles = _convert_to_news_articles(raw)
    _set_cached_bulk_news(lookback_period, articles)
    return articles
|
|
|
|
|
|
def get_category_for_method(method: str) -> str:
    """Return the TOOLS_CATEGORIES key whose tool list contains `method`.

    Raises:
        ValueError: If no category lists the method.
    """
    match = next(
        (
            category
            for category, info in TOOLS_CATEGORIES.items()
            if method in info["tools"]
        ),
        None,
    )
    if match is None:
        raise ValueError(f"Method '{method}' not found in any category")
    return match
|
|
|
|
|
|
def get_vendor(category: str, method: str | None = None) -> str:
    """Resolve the configured vendor string for a data category.

    A per-tool override in config["tool_vendors"] (keyed by method name)
    takes precedence over the category-level setting in
    config["data_vendors"]; when neither is set, "default" is returned.

    Args:
        category: Category key from TOOLS_CATEGORIES.
        method: Optional method name used to look up a per-tool override.
            (Annotation fixed from the implicit-Optional `method: str = None`.)

    Returns:
        The vendor name, or a comma-separated list of vendor names.
    """
    config = get_config()

    # Per-tool override wins over the category default.
    if method:
        tool_vendors = config.get("tool_vendors", {})
        if method in tool_vendors:
            return tool_vendors[method]

    return config.get("data_vendors", {}).get(category, "default")
|
|
|
|
|
|
def route_to_vendor(method: str, *args, **kwargs):
    """Dispatch a data-fetch method to its configured vendor(s) with fallback.

    Flow:
      1. Check the database cache (keyed by method/ticker/date); return on hit.
      2. Resolve the configured primary vendor(s) for the method's category,
         then extend with every other vendor that implements the method as
         fallbacks.
      3. Try vendors in order; a vendor entry may be a single callable or a
         list of callables (all invoked, results combined).
      4. Join multiple results into one newline-separated string, write the
         final result to the DB cache, and return it.

    Args:
        method: Router method name (a key of VENDOR_METHODS).
        *args: Positional args forwarded to the vendor implementation;
            by convention args[0] is the ticker and args[1] the date.
        **kwargs: Keyword args forwarded to the vendor implementation;
            "ticker"/"date" are also read for cache keying.

    Returns:
        A single vendor result, or a newline-joined string of several.

    Raises:
        ValueError: If the method is not in VENDOR_METHODS.
        RuntimeError: If every vendor attempt fails or yields no results.
    """
    # Positional fallbacks assume the (ticker, date) calling convention —
    # NOTE(review): confirm all routed methods follow it.
    ticker = kwargs.get("ticker") or (args[0] if args else None)
    date = kwargs.get("date") or (args[1] if len(args) > 1 else None)

    cached_result = _get_db_cached_data(method, ticker, date=date)
    if cached_result is not None:
        logger.info("DB cache hit for %s (ticker=%s)", method, ticker)
        return cached_result

    category = get_category_for_method(method)
    vendor_config = get_vendor(category, method)

    # Config may specify several primaries as a comma-separated string.
    primary_vendors = [v.strip() for v in vendor_config.split(",")]

    if method not in VENDOR_METHODS:
        raise ValueError(f"Method '{method}' not supported")

    all_available_vendors = list(VENDOR_METHODS[method].keys())

    # Full attempt order: primaries first (as configured), then every other
    # implementing vendor as fallback.
    fallback_vendors = primary_vendors.copy()
    for vendor in all_available_vendors:
        if vendor not in fallback_vendors:
            fallback_vendors.append(vendor)

    primary_str = " -> ".join(primary_vendors)
    fallback_str = " -> ".join(fallback_vendors)
    logger.debug(
        "%s - Primary: [%s] | Full fallback order: [%s]",
        method,
        primary_str,
        fallback_str,
    )

    results = []
    vendor_attempt_count = 0
    # NOTE(review): any_primary_vendor_attempted and successful_vendor are
    # written but never read in this function — possibly vestigial; confirm
    # before removing.
    any_primary_vendor_attempted = False
    successful_vendor = None

    for vendor in fallback_vendors:
        if vendor not in VENDOR_METHODS[method]:
            # Only warn when a *configured* vendor lacks an implementation;
            # silently skip unknown names that were never requested.
            if vendor in primary_vendors:
                logger.info(
                    "Vendor '%s' not supported for method '%s', falling back to next vendor",
                    vendor,
                    method,
                )
            continue

        vendor_impl = VENDOR_METHODS[method][vendor]
        is_primary_vendor = vendor in primary_vendors
        vendor_attempt_count += 1

        if is_primary_vendor:
            any_primary_vendor_attempted = True

        vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK"
        logger.debug(
            "Attempting %s vendor '%s' for %s (attempt #%d)",
            vendor_type,
            vendor,
            method,
            vendor_attempt_count,
        )

        # A vendor entry may be one callable or a list of callables
        # (e.g. "local" news aggregates several sources).
        if isinstance(vendor_impl, list):
            vendor_methods = [(impl, vendor) for impl in vendor_impl]
            logger.debug(
                "Vendor '%s' has multiple implementations: %d functions",
                vendor,
                len(vendor_methods),
            )
        else:
            vendor_methods = [(vendor_impl, vendor)]

        vendor_results = []
        for impl_func, vendor_name in vendor_methods:
            try:
                logger.debug(
                    "Calling %s from vendor '%s'...", impl_func.__name__, vendor_name
                )
                result = impl_func(*args, **kwargs)
                vendor_results.append(result)
                logger.info(
                    "%s from vendor '%s' completed successfully",
                    impl_func.__name__,
                    vendor_name,
                )

            except AlphaVantageRateLimitError as e:
                if vendor == "alpha_vantage":
                    logger.warning(
                        "Alpha Vantage rate limit exceeded, falling back to next available vendor"
                    )
                logger.debug("Rate limit details: %s", e)
                continue
            except (
                RuntimeError,
                ConnectionError,
                TimeoutError,
                ValueError,
                KeyError,
                OSError,
            ) as e:
                logger.error(
                    "%s from vendor '%s' failed: %s", impl_func.__name__, vendor_name, e
                )
                continue

        if vendor_results:
            results.extend(vendor_results)
            successful_vendor = vendor
            result_summary = f"Got {len(vendor_results)} result(s)"
            logger.info("Vendor '%s' succeeded - %s", vendor, result_summary)

            # With a single configured primary, the first success ends the
            # search; multi-primary configs keep collecting from all vendors.
            if len(primary_vendors) == 1:
                logger.debug(
                    "Stopping after successful vendor '%s' (single-vendor config)",
                    vendor,
                )
                break
        else:
            logger.error("Vendor '%s' produced no results", vendor)

    if not results:
        logger.error(
            "All %d vendor attempts failed for method '%s'",
            vendor_attempt_count,
            method,
        )
        raise RuntimeError(f"All vendor implementations failed for method '{method}'")
    else:
        logger.info(
            "Method '%s' completed with %d result(s) from %d vendor attempt(s)",
            method,
            len(results),
            vendor_attempt_count,
        )

    # Multiple results are flattened into one newline-separated string,
    # which is also what gets cached.
    if len(results) == 1:
        final_result = results[0]
    else:
        final_result = "\n".join(str(result) for result in results)

    _set_db_cached_data(method, final_result, ticker, date=date)

    return final_result