This commit is contained in:
Youssef Aitousarrah 2025-12-06 15:39:49 -08:00
parent 5cf57e5d97
commit ccc78c694b
11 changed files with 552 additions and 43 deletions

View File

@ -39,7 +39,7 @@ Volatility Indicators:
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then call get_indicators SEPARATELY for EACH indicator you want to analyze (e.g., call get_indicators once with indicator="rsi", then call it again with indicator="macd", etc.). Do NOT pass multiple indicators in a single call. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
)

View File

@ -1,6 +1,7 @@
from typing import Union, Dict, Optional
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
def get_news(ticker: str = None, start_date: str = None, end_date: str = None, query: str = None) -> dict[str, str] | str:
def get_news(ticker: str = None, start_date: str = None, end_date: str = None, query: str = None) -> Union[Dict[str, str], str]:
"""Returns live and historical market news & sentiment data.
Args:
@ -28,7 +29,7 @@ def get_news(ticker: str = None, start_date: str = None, end_date: str = None, q
return _make_api_request("NEWS_SENTIMENT", params)
def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> dict[str, str] | str:
def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> Union[Dict[str, str], str]:
"""Returns global market news & sentiment data.
Args:
@ -48,7 +49,7 @@ def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> dict[
return _make_api_request("NEWS_SENTIMENT", params)
def get_insider_transactions(symbol: str = None, ticker: str = None, curr_date: str = None) -> dict[str, str] | str:
def get_insider_transactions(symbol: str = None, ticker: str = None, curr_date: str = None) -> Union[Dict[str, str], str]:
"""Returns latest and historical insider transactions.
Args:

View File

@ -63,3 +63,157 @@ def get_recommendation_trends(
except Exception as e:
return f"Error fetching recommendation trends for {ticker}: {str(e)}"
def get_earnings_calendar(
from_date: Annotated[str, "Start date in yyyy-mm-dd format"],
to_date: Annotated[str, "End date in yyyy-mm-dd format"]
) -> str:
"""
Get earnings calendar for stocks with upcoming earnings announcements.
Args:
from_date: Start date in yyyy-mm-dd format
to_date: End date in yyyy-mm-dd format
Returns:
str: Formatted report of upcoming earnings
"""
try:
client = get_finnhub_client()
data = client.earnings_calendar(
_from=from_date,
to=to_date,
symbol="", # Empty string returns all stocks
international=False
)
if not data or 'earningsCalendar' not in data:
return f"No earnings data found for period {from_date} to {to_date}"
earnings = data['earningsCalendar']
if not earnings:
return f"No earnings scheduled between {from_date} and {to_date}"
# Format the response
result = f"## Earnings Calendar ({from_date} to {to_date})\n\n"
result += f"**Total Companies**: {len(earnings)}\n\n"
# Group by date
by_date = {}
for entry in earnings:
date = entry.get('date', 'Unknown')
if date not in by_date:
by_date[date] = []
by_date[date].append(entry)
# Format by date
for date in sorted(by_date.keys()):
result += f"### {date}\n\n"
for entry in by_date[date]:
symbol = entry.get('symbol', 'N/A')
eps_estimate = entry.get('epsEstimate', 'N/A')
eps_actual = entry.get('epsActual', 'N/A')
revenue_estimate = entry.get('revenueEstimate', 'N/A')
revenue_actual = entry.get('revenueActual', 'N/A')
hour = entry.get('hour', 'N/A')
result += f"**{symbol}**"
if hour != 'N/A':
result += f" ({hour})"
result += "\n"
if eps_estimate != 'N/A':
result += f" - EPS Estimate: ${eps_estimate:.2f}" if isinstance(eps_estimate, (int, float)) else f" - EPS Estimate: {eps_estimate}"
if eps_actual != 'N/A':
result += f" | Actual: ${eps_actual:.2f}" if isinstance(eps_actual, (int, float)) else f" | Actual: {eps_actual}"
result += "\n"
if revenue_estimate != 'N/A':
result += f" - Revenue Estimate: ${revenue_estimate:,.0f}M" if isinstance(revenue_estimate, (int, float)) else f" - Revenue Estimate: {revenue_estimate}"
if revenue_actual != 'N/A':
result += f" | Actual: ${revenue_actual:,.0f}M" if isinstance(revenue_actual, (int, float)) else f" | Actual: {revenue_actual}"
result += "\n"
result += "\n"
return result
except Exception as e:
return f"Error fetching earnings calendar: {str(e)}"
def get_ipo_calendar(
from_date: Annotated[str, "Start date in yyyy-mm-dd format"],
to_date: Annotated[str, "End date in yyyy-mm-dd format"]
) -> str:
"""
Get IPO calendar for upcoming and recent initial public offerings.
Args:
from_date: Start date in yyyy-mm-dd format
to_date: End date in yyyy-mm-dd format
Returns:
str: Formatted report of IPOs
"""
try:
client = get_finnhub_client()
data = client.ipo_calendar(
_from=from_date,
to=to_date
)
if not data or 'ipoCalendar' not in data:
return f"No IPO data found for period {from_date} to {to_date}"
ipos = data['ipoCalendar']
if not ipos:
return f"No IPOs scheduled between {from_date} and {to_date}"
# Format the response
result = f"## IPO Calendar ({from_date} to {to_date})\n\n"
result += f"**Total IPOs**: {len(ipos)}\n\n"
# Group by date
by_date = {}
for entry in ipos:
date = entry.get('date', 'Unknown')
if date not in by_date:
by_date[date] = []
by_date[date].append(entry)
# Format by date
for date in sorted(by_date.keys()):
result += f"### {date}\n\n"
for entry in by_date[date]:
symbol = entry.get('symbol', 'N/A')
name = entry.get('name', 'N/A')
exchange = entry.get('exchange', 'N/A')
price = entry.get('price', 'N/A')
shares = entry.get('numberOfShares', 'N/A')
total_shares = entry.get('totalSharesValue', 'N/A')
status = entry.get('status', 'N/A')
result += f"**{symbol}** - {name}\n"
result += f" - Exchange: {exchange}\n"
if price != 'N/A':
result += f" - Price: ${price}\n"
if shares != 'N/A':
result += f" - Shares Offered: {shares:,}\n" if isinstance(shares, (int, float)) else f" - Shares Offered: {shares}\n"
if total_shares != 'N/A':
result += f" - Total Value: ${total_shares:,.0f}M\n" if isinstance(total_shares, (int, float)) else f" - Total Value: {total_shares}\n"
result += f" - Status: {status}\n\n"
return result
except Exception as e:
return f"Error fetching IPO calendar: {str(e)}"

View File

@ -5,13 +5,23 @@ from .googlenews_utils import getNewsData
def get_google_news(
query: Annotated[str, "Query to search with"],
query: Annotated[str, "Query to search with"] = None,
ticker: Annotated[str, "Ticker symbol (alias for query)"] = None,
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"] = None,
look_back_days: Annotated[int, "how many days to look back"] = None,
start_date: Annotated[str, "Start date in yyyy-mm-dd format"] = None,
end_date: Annotated[str, "End date in yyyy-mm-dd format"] = None,
) -> str:
query = query.replace(" ", "+")
# Handle parameter aliasing (query or ticker)
if query:
search_query = query
elif ticker:
# Format ticker as a natural language query for better results
search_query = f"latest news on {ticker} stock"
else:
raise ValueError("Must provide either 'query' or 'ticker' parameter")
search_query = search_query.replace(" ", "+")
# Determine date range
if start_date and end_date:
@ -24,7 +34,7 @@ def get_google_news(
else:
raise ValueError("Must provide either (start_date, end_date) or (curr_date, look_back_days)")
news_results = getNewsData(query, before, target_date)
news_results = getNewsData(search_query, before, target_date)
news_str = ""
@ -36,7 +46,7 @@ def get_google_news(
if len(news_results) == 0:
return ""
return f"## {query} Google News, from {before} to {target_date}:\n\n{news_str}"
return f"## {search_query} Google News, from {before} to {target_date}:\n\n{news_str}"
def get_global_news_google(

View File

@ -2,7 +2,24 @@ from openai import OpenAI
from .config import get_config
def get_stock_news_openai(query, start_date, end_date):
def get_stock_news_openai(query=None, ticker=None, start_date=None, end_date=None):
"""Get stock news from OpenAI web search.
Args:
query: Search query or ticker symbol
ticker: Ticker symbol (alias for query)
start_date: Start date yyyy-mm-dd
end_date: End date yyyy-mm-dd
"""
# Handle parameter aliasing
if query:
search_query = query
elif ticker:
# Format ticker as a natural language query for better results
search_query = f"latest news and market analysis on {ticker} stock"
else:
raise ValueError("Must provide either 'query' or 'ticker' parameter")
config = get_config()
client = OpenAI(base_url=config["backend_url"])
@ -10,7 +27,7 @@ def get_stock_news_openai(query, start_date, end_date):
response = client.responses.create(
model="gpt-4o-mini",
tools=[{"type": "web_search_preview"}],
input=f"Search Social Media for {query} from {start_date} to {end_date}. Make sure you only get the data posted during that period."
input=f"Search Social Media and news sources for {search_query} from {start_date} to {end_date}. Make sure you only get the data posted during that period."
)
return response.output_text
except Exception as e:

View File

@ -11,7 +11,7 @@ from tradingagents.agents.utils.agent_utils import (
get_indicators
)
from tradingagents.tools.executor import execute_tool
from tradingagents.schemas import TickerList, MarketMovers
from tradingagents.schemas import TickerList, MarketMovers, ThemeList
class DiscoveryGraph:
def __init__(self, config=None):
@ -76,6 +76,66 @@ class DiscoveryGraph:
candidates = []
# 0. Macro Theme Discovery (Top-Down)
try:
from datetime import datetime
today = datetime.now().strftime("%Y-%m-%d")
# Get Global News
global_news = execute_tool("get_global_news", date=today, limit=5)
# Extract Themes
prompt = f"""Based on this global news, identify 3 trending market themes or sectors (e.g., 'Artificial Intelligence', 'Oil', 'Biotech').
Return a JSON object with a 'themes' array of strings.
News:
{global_news}
"""
structured_llm = self.quick_thinking_llm.with_structured_output(
schema=ThemeList.model_json_schema(),
method="json_schema"
)
response = structured_llm.invoke([HumanMessage(content=prompt)])
themes = response.get("themes", [])
print(f" Identified Macro Themes: {themes}")
# Find tickers for each theme
for theme in themes:
try:
tweets_report = execute_tool("get_tweets", query=f"{theme} stocks", count=15)
prompt = f"""Extract ONLY valid stock ticker symbols related to the theme '{theme}' from this report.
Return a comma-separated list of tickers (1-5 uppercase letters).
Report:
{tweets_report}
Return a JSON object with a 'tickers' array."""
structured_llm = self.quick_thinking_llm.with_structured_output(
schema=TickerList.model_json_schema(),
method="json_schema"
)
response = structured_llm.invoke([HumanMessage(content=prompt)])
theme_tickers = response.get("tickers", [])
for t in theme_tickers:
t = t.upper().strip()
if re.match(r'^[A-Z]{1,5}$', t):
# Use validate_ticker tool logic (via execute_tool)
try:
if execute_tool("validate_ticker", symbol=t):
candidates.append({"ticker": t, "source": f"macro_theme_{theme}", "sentiment": "unknown"})
except Exception:
continue
except Exception as e:
print(f" Error fetching tickers for theme {theme}: {e}")
except Exception as e:
print(f" Error in Macro Theme Discovery: {e}")
# 1. Get Reddit Trending (Social Sentiment)
try:
reddit_report = execute_tool("get_trending_tickers", limit=self.reddit_trending_limit)
@ -188,7 +248,73 @@ Data:
except Exception as e:
print(f" Error fetching Market Movers: {e}")
# 3. Get Earnings Calendar (Event-based Discovery)
try:
from datetime import datetime, timedelta
today = datetime.now()
from_date = today.strftime("%Y-%m-%d")
to_date = (today + timedelta(days=7)).strftime("%Y-%m-%d") # Next 7 days
earnings_report = execute_tool("get_earnings_calendar", from_date=from_date, to_date=to_date)
# Extract tickers from earnings calendar
prompt = """Extract ONLY valid stock ticker symbols from this earnings calendar.
Return a comma-separated list of tickers (1-5 uppercase letters).
Only include actual stock tickers, not indexes or other symbols.
Earnings Calendar:
{report}
Return a JSON object with a 'tickers' array containing only valid stock ticker symbols.""".format(report=earnings_report)
structured_llm = self.quick_thinking_llm.with_structured_output(
schema=TickerList.model_json_schema(),
method="json_schema"
)
response = structured_llm.invoke([HumanMessage(content=prompt)])
earnings_tickers = response.get("tickers", [])
for t in earnings_tickers:
t = t.upper().strip()
if re.match(r'^[A-Z]{1,5}$', t):
candidates.append({"ticker": t, "source": "earnings_catalyst", "sentiment": "unknown"})
except Exception as e:
print(f" Error fetching Earnings Calendar: {e}")
# 4. Get IPO Calendar (New Listings Discovery)
try:
from datetime import datetime, timedelta
today = datetime.now()
from_date = (today - timedelta(days=7)).strftime("%Y-%m-%d") # Past 7 days
to_date = (today + timedelta(days=14)).strftime("%Y-%m-%d") # Next 14 days
ipo_report = execute_tool("get_ipo_calendar", from_date=from_date, to_date=to_date)
# Extract tickers from IPO calendar
prompt = """Extract ONLY valid stock ticker symbols from this IPO calendar.
Return a comma-separated list of tickers (1-5 uppercase letters).
Only include actual stock tickers that are listed or about to be listed.
IPO Calendar:
{report}
Return a JSON object with a 'tickers' array containing only valid stock ticker symbols.""".format(report=ipo_report)
structured_llm = self.quick_thinking_llm.with_structured_output(
schema=TickerList.model_json_schema(),
method="json_schema"
)
response = structured_llm.invoke([HumanMessage(content=prompt)])
ipo_tickers = response.get("tickers", [])
for t in ipo_tickers:
t = t.upper().strip()
if re.match(r'^[A-Z]{1,5}$', t):
candidates.append({"ticker": t, "source": "ipo_listing", "sentiment": "unknown"})
except Exception as e:
print(f" Error fetching IPO Calendar: {e}")
# Deduplicate
unique_candidates = {}
for c in candidates:
@ -232,8 +358,30 @@ Data:
strategy = "contrarian_value"
elif source == "social_trending" or source == "twitter_sentiment":
strategy = "social_hype"
elif source == "earnings_catalyst":
strategy = "earnings_play"
elif source == "ipo_listing":
strategy = "ipo_opportunity"
cand['strategy'] = strategy
# Technical Analysis Check (New)
try:
from datetime import datetime
today = datetime.now().strftime("%Y-%m-%d")
# Get RSI
rsi_data = execute_tool("get_indicators", symbol=ticker, indicator="rsi", curr_date=today, look_back_days=14)
# Simple parsing of the string report to find the latest value
# The report format is usually "## rsi values...\n\nDATE: VALUE"
# We'll just store the report for the LLM to analyze in deep dive if needed,
# OR we can try to parse it here. For now, let's just add it to metadata.
cand['technical_indicators'] = rsi_data
except Exception as e:
print(f" Error getting technicals for {ticker}: {e}")
filtered_candidates.append(cand)
except Exception as e:
@ -276,8 +424,9 @@ Data:
# 1. Get News Sentiment
news = execute_tool("get_news", ticker=ticker, start_date=start_date, end_date=end_date)
# 2. Get Insider Transactions
# 2. Get Insider Transactions & Sentiment
insider = execute_tool("get_insider_transactions", ticker=ticker)
insider_sentiment = execute_tool("get_insider_sentiment", ticker=ticker)
# 3. Get Fundamentals (for the Contrarian check)
fundamentals = execute_tool("get_fundamentals", ticker=ticker, curr_date=end_date)
@ -290,6 +439,7 @@ Data:
"strategy": strategy,
"news": news,
"insider_transactions": insider,
"insider_sentiment": insider_sentiment,
"fundamentals": fundamentals,
"recommendations": recommendations
})
@ -313,6 +463,7 @@ Data:
"strategy": opp["strategy"],
# Truncate to ~1000 chars each (roughly 250 tokens)
"news": opp["news"][:1000] + "..." if len(opp["news"]) > 1000 else opp["news"],
"insider_sentiment": opp.get("insider_sentiment", "")[:500],
"insider_transactions": opp["insider_transactions"][:1000] + "..." if len(opp["insider_transactions"]) > 1000 else opp["insider_transactions"],
"fundamentals": opp["fundamentals"][:1000] + "..." if len(opp["fundamentals"]) > 1000 else opp["fundamentals"],
"recommendations": opp["recommendations"][:1000] + "..." if len(opp["recommendations"]) > 1000 else opp["recommendations"],

View File

@ -3,6 +3,7 @@
from .llm_outputs import (
TradeDecision,
TickerList,
ThemeList,
MarketMover,
MarketMovers,
InvestmentOpportunity,
@ -14,6 +15,7 @@ from .llm_outputs import (
__all__ = [
"TradeDecision",
"TickerList",
"ThemeList",
"MarketMovers",
"MarketMover",
"InvestmentOpportunity",

View File

@ -26,6 +26,7 @@ class TradeDecision(BaseModel):
)
class TickerList(BaseModel):
"""Structured output for ticker symbol lists."""
@ -34,6 +35,14 @@ class TickerList(BaseModel):
)
class ThemeList(BaseModel):
"""Structured output for market themes."""
themes: List[str] = Field(
description="List of trending market themes or sectors"
)
class MarketMover(BaseModel):
"""Individual market mover entry."""

View File

@ -6,14 +6,16 @@ registry-based approach. All routing decisions are driven by the tool registry.
Key improvements over old system:
- Single registry lookup instead of multiple dictionary lookups
- Clear fallback logic per tool (optional, not mandatory)
- Supports both fallback and aggregate execution modes
- Parallel vendor execution for aggregate mode
- Better error messages and debugging
- No dual registry systems
"""
from typing import Any, Optional, List
from typing import Any, Optional, List, Dict
import logging
from tradingagents.tools.registry import TOOL_REGISTRY, get_vendor_config
import concurrent.futures
from tradingagents.tools.registry import TOOL_REGISTRY, get_vendor_config, get_tool_metadata
logger = logging.getLogger(__name__)
@ -28,43 +30,30 @@ class VendorNotFoundError(Exception):
pass
def execute_tool(tool_name: str, *args, **kwargs) -> Any:
"""Execute a tool using the registry-based routing system.
def _execute_fallback(tool_name: str, vendor_config: Dict, *args, **kwargs) -> Any:
"""Execute vendors sequentially with fallback (original behavior).
This is the simplified replacement for route_to_vendor().
Tries vendors in priority order and returns the first successful result.
Args:
tool_name: Name of the tool to execute (e.g., "get_stock_data")
*args: Positional arguments to pass to the tool
**kwargs: Keyword arguments to pass to the tool
tool_name: Name of the tool
vendor_config: Vendor configuration from registry
*args: Positional arguments for vendor function
**kwargs: Keyword arguments for vendor function
Returns:
Result from the vendor function
Result from first successful vendor
Raises:
VendorNotFoundError: If tool or vendor implementation not found
ToolExecutionError: If all vendors fail to execute the tool
ToolExecutionError: If all vendors fail
"""
# Step 1: Get vendor configuration from registry
vendor_config = get_vendor_config(tool_name)
if not vendor_config["vendor_priority"]:
raise VendorNotFoundError(
f"Tool '{tool_name}' not found in registry or has no vendors configured"
)
# Step 2: Get vendor functions and priority list
vendor_functions = vendor_config["vendors"]
vendors_to_try = vendor_config["vendor_priority"]
logger.debug(f"Executing tool '{tool_name}' with vendors: {vendors_to_try}")
# Step 3: Try each vendor in priority order
errors = []
logger.debug(f"Executing tool '{tool_name}' in fallback mode with vendors: {vendors_to_try}")
for vendor_name in vendors_to_try:
# Get the vendor function directly from registry
vendor_func = vendor_functions.get(vendor_name)
if not vendor_func:
@ -80,16 +69,131 @@ def execute_tool(tool_name: str, *args, **kwargs) -> Any:
error_msg = f"Vendor '{vendor_name}' failed: {str(e)}"
logger.warning(f"Tool '{tool_name}': {error_msg}")
errors.append(error_msg)
# Continue to next vendor (fallback)
continue
# Step 4: All vendors failed
# All vendors failed
error_summary = f"Tool '{tool_name}' failed with all vendors:\n" + "\n".join(f" - {err}" for err in errors)
logger.error(error_summary)
raise ToolExecutionError(error_summary)
def _execute_aggregate(tool_name: str, vendor_config: Dict, metadata: Dict, *args, **kwargs) -> str:
"""Execute multiple vendors in parallel and aggregate results.
Executes all specified vendors simultaneously using ThreadPoolExecutor,
collects successful results, and combines them with vendor labels.
Args:
tool_name: Name of the tool
vendor_config: Vendor configuration from registry
metadata: Tool metadata from registry
*args: Positional arguments for vendor functions
**kwargs: Keyword arguments for vendor functions
Returns:
Aggregated results from all successful vendors, formatted with labels
Raises:
ToolExecutionError: If all vendors fail
"""
vendor_functions = vendor_config["vendors"]
# Get list of vendors to aggregate (default to all in priority list)
vendors_to_aggregate = metadata.get("aggregate_vendors") or vendor_config["vendor_priority"]
logger.debug(f"Executing tool '{tool_name}' in aggregate mode with vendors: {vendors_to_aggregate}")
results = []
errors = []
# Execute vendors in parallel using ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=len(vendors_to_aggregate)) as executor:
# Submit all vendor calls
future_to_vendor = {}
for vendor_name in vendors_to_aggregate:
vendor_func = vendor_functions.get(vendor_name)
if vendor_func:
future = executor.submit(vendor_func, *args, **kwargs)
future_to_vendor[future] = vendor_name
else:
logger.warning(f"Vendor '{vendor_name}' not found in vendors dict for tool '{tool_name}'")
# Collect results as they complete
for future in concurrent.futures.as_completed(future_to_vendor):
vendor_name = future_to_vendor[future]
try:
result = future.result()
results.append({
"vendor": vendor_name,
"data": result
})
logger.debug(f"Tool '{tool_name}': vendor '{vendor_name}' succeeded")
except Exception as e:
error_msg = f"Vendor '{vendor_name}' failed: {str(e)}"
errors.append(error_msg)
logger.warning(f"Tool '{tool_name}': {error_msg}")
# Check if we got any results
if not results:
error_summary = f"Tool '{tool_name}' aggregate mode: all vendors failed:\n" + "\n".join(f" - {err}" for err in errors)
logger.error(error_summary)
raise ToolExecutionError(error_summary)
# Format aggregated results with clear vendor labels
formatted_results = []
for item in results:
vendor_label = f"=== {item['vendor'].upper()} ==="
formatted_results.append(f"{vendor_label}\n{item['data']}")
# Log partial success if some vendors failed
if errors:
logger.info(f"Tool '{tool_name}': {len(results)} vendors succeeded, {len(errors)} failed")
return "\n\n".join(formatted_results)
def execute_tool(tool_name: str, *args, **kwargs) -> Any:
"""Execute a tool using fallback or aggregate mode based on configuration.
This is the main entry point for tool execution. It dispatches to either
fallback mode (sequential with early return) or aggregate mode (parallel
with result combination) based on the tool's execution_mode setting.
Args:
tool_name: Name of the tool to execute (e.g., "get_stock_data")
*args: Positional arguments to pass to the tool
**kwargs: Keyword arguments to pass to the tool
Returns:
Result from vendor function(s). String for aggregate mode (formatted
with vendor labels), Any for fallback mode (raw vendor result).
Raises:
VendorNotFoundError: If tool or vendor implementation not found
ToolExecutionError: If all vendors fail to execute the tool
"""
# Get vendor configuration and metadata from registry
vendor_config = get_vendor_config(tool_name)
metadata = get_tool_metadata(tool_name)
if not vendor_config["vendor_priority"]:
raise VendorNotFoundError(
f"Tool '{tool_name}' not found in registry or has no vendors configured"
)
if not metadata:
raise VendorNotFoundError(f"Tool '{tool_name}' metadata not found in registry")
# Check execution mode (defaults to fallback for backward compatibility)
execution_mode = metadata.get("execution_mode", "fallback")
# Dispatch to appropriate execution strategy
if execution_mode == "aggregate":
return _execute_aggregate(tool_name, vendor_config, metadata, *args, **kwargs)
else:
return _execute_fallback(tool_name, vendor_config, *args, **kwargs)
def get_tool_info(tool_name: str) -> Optional[dict]:
"""Get information about a tool from the registry.

View File

@ -54,6 +54,10 @@ def generate_langchain_tool(tool_name: str, metadata: Dict[str, Any]) -> Callabl
# Use **kwargs to handle all parameters
def tool_function(**kwargs):
"""Dynamically generated tool function."""
# Ensure defaults are applied for missing parameters
for param_name, param_info in parameters.items():
if param_name not in kwargs and "default" in param_info:
kwargs[param_name] = param_info["default"]
return execute_tool(tool_name, **kwargs)
# Set function metadata

View File

@ -53,6 +53,8 @@ from tradingagents.dataflows.reddit_api import (
)
from tradingagents.dataflows.finnhub_api import (
get_recommendation_trends as get_finnhub_recommendation_trends,
get_earnings_calendar as get_finnhub_earnings_calendar,
get_ipo_calendar as get_finnhub_ipo_calendar,
)
from tradingagents.dataflows.twitter_data import (
get_tweets as get_twitter_tweets,
@ -109,6 +111,8 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {
"alpha_vantage": get_alpha_vantage_indicator,
},
"vendor_priority": ["yfinance", "alpha_vantage"],
"execution_mode": "aggregate",
"aggregate_vendors": ["yfinance", "alpha_vantage"],
"parameters": {
"symbol": {"type": "str", "description": "Ticker symbol"},
"indicator": {"type": "str", "description": "Technical indicator (rsi, macd, sma, ema, etc.)"},
@ -208,6 +212,8 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {
"google": get_google_news,
},
"vendor_priority": ["alpha_vantage", "reddit", "openai", "google"],
"execution_mode": "aggregate",
"aggregate_vendors": ["alpha_vantage", "reddit", "google"],
"parameters": {
"query": {"type": "str", "description": "Search query or ticker symbol"},
"start_date": {"type": "str", "description": "Start date, yyyy-mm-dd"},
@ -227,6 +233,7 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {
"alpha_vantage": get_alpha_vantage_global_news,
},
"vendor_priority": ["openai", "google", "reddit", "alpha_vantage"],
"execution_mode": "aggregate",
"parameters": {
"date": {"type": "str", "description": "Date for news, yyyy-mm-dd"},
"look_back_days": {"type": "int", "description": "Days to look back", "default": 7},
@ -310,6 +317,36 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {
"returns": "str: Tweets matching the query",
},
"get_earnings_calendar": {
"description": "Get upcoming earnings announcements (catalysts for volatility)",
"category": "discovery",
"agents": [],
"vendors": {
"finnhub": get_finnhub_earnings_calendar,
},
"vendor_priority": ["finnhub"],
"parameters": {
"from_date": {"type": "str", "description": "Start date in yyyy-mm-dd format"},
"to_date": {"type": "str", "description": "End date in yyyy-mm-dd format"},
},
"returns": "str: Formatted earnings calendar with EPS and revenue estimates",
},
"get_ipo_calendar": {
"description": "Get upcoming and recent IPOs (new listing opportunities)",
"category": "discovery",
"agents": [],
"vendors": {
"finnhub": get_finnhub_ipo_calendar,
},
"vendor_priority": ["finnhub"],
"parameters": {
"from_date": {"type": "str", "description": "Start date in yyyy-mm-dd format"},
"to_date": {"type": "str", "description": "End date in yyyy-mm-dd format"},
},
"returns": "str: Formatted IPO calendar with pricing and share details",
},
"get_reddit_discussions": {
"description": "Get Reddit discussions about a specific ticker",
"category": "news_data",
@ -465,6 +502,26 @@ def validate_registry() -> List[str]:
if not isinstance(metadata.get("parameters"), dict):
issues.append(f"{tool_name}: Parameters must be a dictionary")
# Validate execution_mode if present
if "execution_mode" in metadata:
execution_mode = metadata["execution_mode"]
if execution_mode not in ["fallback", "aggregate"]:
issues.append(f"{tool_name}: Invalid execution_mode '{execution_mode}', must be 'fallback' or 'aggregate'")
# Validate aggregate_vendors if present
if "aggregate_vendors" in metadata:
aggregate_vendors = metadata["aggregate_vendors"]
if not isinstance(aggregate_vendors, list):
issues.append(f"{tool_name}: aggregate_vendors must be a list")
else:
for vendor_name in aggregate_vendors:
if vendor_name not in vendors:
issues.append(f"{tool_name}: aggregate_vendor '{vendor_name}' not in vendors dict")
# Warn if aggregate_vendors specified but execution_mode is not aggregate
if metadata.get("execution_mode") != "aggregate":
issues.append(f"{tool_name}: aggregate_vendors specified but execution_mode is not 'aggregate'")
return issues