diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 48a9fdce..60b3434f 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -39,7 +39,7 @@ Volatility Indicators: Volume-Based Indicators: - vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses. -- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.""" +- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then call get_indicators SEPARATELY for EACH indicator you want to analyze (e.g., call get_indicators once with indicator="rsi", then call it again with indicator="macd", etc.). Do NOT pass multiple indicators in a single call. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.""" + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" ) diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py index 0e199fec..8002735e 100644 --- a/tradingagents/dataflows/alpha_vantage_news.py +++ b/tradingagents/dataflows/alpha_vantage_news.py @@ -1,6 +1,7 @@ +from typing import Union, Dict, Optional from .alpha_vantage_common import _make_api_request, format_datetime_for_api -def get_news(ticker: str = None, start_date: str = None, end_date: str = None, query: str = None) -> dict[str, str] | str: +def get_news(ticker: str = None, start_date: str = None, end_date: str = None, query: str = None) -> Union[Dict[str, str], str]: """Returns live and historical market news & sentiment data. Args: @@ -28,7 +29,7 @@ def get_news(ticker: str = None, start_date: str = None, end_date: str = None, q return _make_api_request("NEWS_SENTIMENT", params) -def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> dict[str, str] | str: +def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> Union[Dict[str, str], str]: """Returns global market news & sentiment data. Args: @@ -48,7 +49,7 @@ def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> dict[ return _make_api_request("NEWS_SENTIMENT", params) -def get_insider_transactions(symbol: str = None, ticker: str = None, curr_date: str = None) -> dict[str, str] | str: +def get_insider_transactions(symbol: str = None, ticker: str = None, curr_date: str = None) -> Union[Dict[str, str], str]: """Returns latest and historical insider transactions. Args: diff --git a/tradingagents/dataflows/finnhub_api.py b/tradingagents/dataflows/finnhub_api.py index 417f55e6..607a6d7b 100644 --- a/tradingagents/dataflows/finnhub_api.py +++ b/tradingagents/dataflows/finnhub_api.py @@ -63,3 +63,157 @@ def get_recommendation_trends( except Exception as e: return f"Error fetching recommendation trends for {ticker}: {str(e)}" + + +def get_earnings_calendar( + from_date: Annotated[str, "Start date in yyyy-mm-dd format"], + to_date: Annotated[str, "End date in yyyy-mm-dd format"] +) -> str: + """ + Get earnings calendar for stocks with upcoming earnings announcements. + + Args: + from_date: Start date in yyyy-mm-dd format + to_date: End date in yyyy-mm-dd format + + Returns: + str: Formatted report of upcoming earnings + """ + try: + client = get_finnhub_client() + data = client.earnings_calendar( + _from=from_date, + to=to_date, + symbol="", # Empty string returns all stocks + international=False + ) + + if not data or 'earningsCalendar' not in data: + return f"No earnings data found for period {from_date} to {to_date}" + + earnings = data['earningsCalendar'] + + if not earnings: + return f"No earnings scheduled between {from_date} and {to_date}" + + # Format the response + result = f"## Earnings Calendar ({from_date} to {to_date})\n\n" + result += f"**Total Companies**: {len(earnings)}\n\n" + + # Group by date + by_date = {} + for entry in earnings: + date = entry.get('date', 'Unknown') + if date not in by_date: + by_date[date] = [] + by_date[date].append(entry) + + # Format by date + for date in sorted(by_date.keys()): + result += f"### {date}\n\n" + + for entry in by_date[date]: + symbol = entry.get('symbol', 'N/A') + eps_estimate = entry.get('epsEstimate', 'N/A') + eps_actual = entry.get('epsActual', 'N/A') + revenue_estimate = entry.get('revenueEstimate', 'N/A') + revenue_actual = entry.get('revenueActual', 'N/A') + hour = entry.get('hour', 'N/A') + + result += f"**{symbol}**" + if hour != 'N/A': + result += f" ({hour})" + result += "\n" + + if eps_estimate != 'N/A': + result += f" - EPS Estimate: ${eps_estimate:.2f}" if isinstance(eps_estimate, (int, float)) else f" - EPS Estimate: {eps_estimate}" + if eps_actual != 'N/A': + result += f" | Actual: ${eps_actual:.2f}" if isinstance(eps_actual, (int, float)) else f" | Actual: {eps_actual}" + result += "\n" + + if revenue_estimate != 'N/A': + result += f" - Revenue Estimate: ${revenue_estimate:,.0f}M" if isinstance(revenue_estimate, (int, float)) else f" - Revenue Estimate: {revenue_estimate}" + if revenue_actual != 'N/A': + result += f" | Actual: ${revenue_actual:,.0f}M" if isinstance(revenue_actual, (int, float)) else f" | Actual: {revenue_actual}" + result += "\n" + + result += "\n" + + return result + + except Exception as e: + return f"Error fetching earnings calendar: {str(e)}" + + +def get_ipo_calendar( + from_date: Annotated[str, "Start date in yyyy-mm-dd format"], + to_date: Annotated[str, "End date in yyyy-mm-dd format"] +) -> str: + """ + Get IPO calendar for upcoming and recent initial public offerings. + + Args: + from_date: Start date in yyyy-mm-dd format + to_date: End date in yyyy-mm-dd format + + Returns: + str: Formatted report of IPOs + """ + try: + client = get_finnhub_client() + data = client.ipo_calendar( + _from=from_date, + to=to_date + ) + + if not data or 'ipoCalendar' not in data: + return f"No IPO data found for period {from_date} to {to_date}" + + ipos = data['ipoCalendar'] + + if not ipos: + return f"No IPOs scheduled between {from_date} and {to_date}" + + # Format the response + result = f"## IPO Calendar ({from_date} to {to_date})\n\n" + result += f"**Total IPOs**: {len(ipos)}\n\n" + + # Group by date + by_date = {} + for entry in ipos: + date = entry.get('date', 'Unknown') + if date not in by_date: + by_date[date] = [] + by_date[date].append(entry) + + # Format by date + for date in sorted(by_date.keys()): + result += f"### {date}\n\n" + + for entry in by_date[date]: + symbol = entry.get('symbol', 'N/A') + name = entry.get('name', 'N/A') + exchange = entry.get('exchange', 'N/A') + price = entry.get('price', 'N/A') + shares = entry.get('numberOfShares', 'N/A') + total_shares = entry.get('totalSharesValue', 'N/A') + status = entry.get('status', 'N/A') + + result += f"**{symbol}** - {name}\n" + result += f" - Exchange: {exchange}\n" + + if price != 'N/A': + result += f" - Price: ${price}\n" + + if shares != 'N/A': + result += f" - Shares Offered: {shares:,}\n" if isinstance(shares, (int, float)) else f" - Shares Offered: {shares}\n" + + if total_shares != 'N/A': + result += f" - Total Value: ${total_shares:,.0f}M\n" if isinstance(total_shares, (int, float)) else f" - Total Value: {total_shares}\n" + + result += f" - Status: {status}\n\n" + + return result + + except Exception as e: + return f"Error fetching IPO calendar: {str(e)}" diff --git a/tradingagents/dataflows/google.py b/tradingagents/dataflows/google.py index 322e1114..975b9f2f 100644 --- a/tradingagents/dataflows/google.py +++ b/tradingagents/dataflows/google.py @@ -5,13 +5,23 @@ from .googlenews_utils import getNewsData def get_google_news( - query: Annotated[str, "Query to search with"], + query: Annotated[str, "Query to search with"] = None, + ticker: Annotated[str, "Ticker symbol (alias for query)"] = None, curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"] = None, look_back_days: Annotated[int, "how many days to look back"] = None, start_date: Annotated[str, "Start date in yyyy-mm-dd format"] = None, end_date: Annotated[str, "End date in yyyy-mm-dd format"] = None, ) -> str: - query = query.replace(" ", "+") + # Handle parameter aliasing (query or ticker) + if query: + search_query = query + elif ticker: + # Format ticker as a natural language query for better results + search_query = f"latest news on {ticker} stock" + else: + raise ValueError("Must provide either 'query' or 'ticker' parameter") + + search_query = search_query.replace(" ", "+") # Determine date range if start_date and end_date: @@ -24,7 +34,7 @@ def get_google_news( else: raise ValueError("Must provide either (start_date, end_date) or (curr_date, look_back_days)") - news_results = getNewsData(query, before, target_date) + news_results = getNewsData(search_query, before, target_date) news_str = "" @@ -36,7 +46,7 @@ def get_google_news( if len(news_results) == 0: return "" - return f"## {query} Google News, from {before} to {target_date}:\n\n{news_str}" + return f"## {search_query} Google News, from {before} to {target_date}:\n\n{news_str}" def get_global_news_google( diff --git a/tradingagents/dataflows/openai.py b/tradingagents/dataflows/openai.py index 9cc5b6e1..2a07a75b 100644 --- a/tradingagents/dataflows/openai.py +++ b/tradingagents/dataflows/openai.py @@ -2,7 +2,24 @@ from openai import OpenAI from .config import get_config -def get_stock_news_openai(query, start_date, end_date): +def get_stock_news_openai(query=None, ticker=None, start_date=None, end_date=None): + """Get stock news from OpenAI web search. + + Args: + query: Search query or ticker symbol + ticker: Ticker symbol (alias for query) + start_date: Start date yyyy-mm-dd + end_date: End date yyyy-mm-dd + """ + # Handle parameter aliasing + if query: + search_query = query + elif ticker: + # Format ticker as a natural language query for better results + search_query = f"latest news and market analysis on {ticker} stock" + else: + raise ValueError("Must provide either 'query' or 'ticker' parameter") + config = get_config() client = OpenAI(base_url=config["backend_url"]) @@ -10,7 +27,7 @@ def get_stock_news_openai(query, start_date, end_date): response = client.responses.create( model="gpt-4o-mini", tools=[{"type": "web_search_preview"}], - input=f"Search Social Media for {query} from {start_date} to {end_date}. Make sure you only get the data posted during that period." + input=f"Search Social Media and news sources for {search_query} from {start_date} to {end_date}. Make sure you only get the data posted during that period." ) return response.output_text except Exception as e: diff --git a/tradingagents/graph/discovery_graph.py b/tradingagents/graph/discovery_graph.py index c6852c28..2200fbfd 100644 --- a/tradingagents/graph/discovery_graph.py +++ b/tradingagents/graph/discovery_graph.py @@ -11,7 +11,7 @@ from tradingagents.agents.utils.agent_utils import ( get_indicators ) from tradingagents.tools.executor import execute_tool -from tradingagents.schemas import TickerList, MarketMovers +from tradingagents.schemas import TickerList, MarketMovers, ThemeList class DiscoveryGraph: def __init__(self, config=None): @@ -76,6 +76,66 @@ class DiscoveryGraph: candidates = [] + # 0. Macro Theme Discovery (Top-Down) + try: + from datetime import datetime + today = datetime.now().strftime("%Y-%m-%d") + + # Get Global News + global_news = execute_tool("get_global_news", date=today, limit=5) + + # Extract Themes + prompt = f"""Based on this global news, identify 3 trending market themes or sectors (e.g., 'Artificial Intelligence', 'Oil', 'Biotech'). + Return a JSON object with a 'themes' array of strings. + + News: + {global_news} + """ + + structured_llm = self.quick_thinking_llm.with_structured_output( + schema=ThemeList.model_json_schema(), + method="json_schema" + ) + response = structured_llm.invoke([HumanMessage(content=prompt)]) + themes = response.get("themes", []) + + print(f" Identified Macro Themes: {themes}") + + # Find tickers for each theme + for theme in themes: + try: + tweets_report = execute_tool("get_tweets", query=f"{theme} stocks", count=15) + + prompt = f"""Extract ONLY valid stock ticker symbols related to the theme '{theme}' from this report. + Return a comma-separated list of tickers (1-5 uppercase letters). + + Report: + {tweets_report} + + Return a JSON object with a 'tickers' array.""" + + structured_llm = self.quick_thinking_llm.with_structured_output( + schema=TickerList.model_json_schema(), + method="json_schema" + ) + response = structured_llm.invoke([HumanMessage(content=prompt)]) + theme_tickers = response.get("tickers", []) + + for t in theme_tickers: + t = t.upper().strip() + if re.match(r'^[A-Z]{1,5}$', t): + # Use validate_ticker tool logic (via execute_tool) + try: + if execute_tool("validate_ticker", symbol=t): + candidates.append({"ticker": t, "source": f"macro_theme_{theme}", "sentiment": "unknown"}) + except Exception: + continue + except Exception as e: + print(f" Error fetching tickers for theme {theme}: {e}") + + except Exception as e: + print(f" Error in Macro Theme Discovery: {e}") + # 1. Get Reddit Trending (Social Sentiment) try: reddit_report = execute_tool("get_trending_tickers", limit=self.reddit_trending_limit) @@ -188,7 +248,73 @@ Data: except Exception as e: print(f" Error fetching Market Movers: {e}") - + + # 3. Get Earnings Calendar (Event-based Discovery) + try: + from datetime import datetime, timedelta + today = datetime.now() + from_date = today.strftime("%Y-%m-%d") + to_date = (today + timedelta(days=7)).strftime("%Y-%m-%d") # Next 7 days + + earnings_report = execute_tool("get_earnings_calendar", from_date=from_date, to_date=to_date) + + # Extract tickers from earnings calendar + prompt = """Extract ONLY valid stock ticker symbols from this earnings calendar. +Return a comma-separated list of tickers (1-5 uppercase letters). +Only include actual stock tickers, not indexes or other symbols. + +Earnings Calendar: +{report} + +Return a JSON object with a 'tickers' array containing only valid stock ticker symbols.""".format(report=earnings_report) + + structured_llm = self.quick_thinking_llm.with_structured_output( + schema=TickerList.model_json_schema(), + method="json_schema" + ) + response = structured_llm.invoke([HumanMessage(content=prompt)]) + + earnings_tickers = response.get("tickers", []) + for t in earnings_tickers: + t = t.upper().strip() + if re.match(r'^[A-Z]{1,5}$', t): + candidates.append({"ticker": t, "source": "earnings_catalyst", "sentiment": "unknown"}) + except Exception as e: + print(f" Error fetching Earnings Calendar: {e}") + + # 4. Get IPO Calendar (New Listings Discovery) + try: + from datetime import datetime, timedelta + today = datetime.now() + from_date = (today - timedelta(days=7)).strftime("%Y-%m-%d") # Past 7 days + to_date = (today + timedelta(days=14)).strftime("%Y-%m-%d") # Next 14 days + + ipo_report = execute_tool("get_ipo_calendar", from_date=from_date, to_date=to_date) + + # Extract tickers from IPO calendar + prompt = """Extract ONLY valid stock ticker symbols from this IPO calendar. +Return a comma-separated list of tickers (1-5 uppercase letters). +Only include actual stock tickers that are listed or about to be listed. + +IPO Calendar: +{report} + +Return a JSON object with a 'tickers' array containing only valid stock ticker symbols.""".format(report=ipo_report) + + structured_llm = self.quick_thinking_llm.with_structured_output( + schema=TickerList.model_json_schema(), + method="json_schema" + ) + response = structured_llm.invoke([HumanMessage(content=prompt)]) + + ipo_tickers = response.get("tickers", []) + for t in ipo_tickers: + t = t.upper().strip() + if re.match(r'^[A-Z]{1,5}$', t): + candidates.append({"ticker": t, "source": "ipo_listing", "sentiment": "unknown"}) + except Exception as e: + print(f" Error fetching IPO Calendar: {e}") + # Deduplicate unique_candidates = {} for c in candidates: @@ -232,8 +358,30 @@ Data: strategy = "contrarian_value" elif source == "social_trending" or source == "twitter_sentiment": strategy = "social_hype" + elif source == "earnings_catalyst": + strategy = "earnings_play" + elif source == "ipo_listing": + strategy = "ipo_opportunity" cand['strategy'] = strategy + + # Technical Analysis Check (New) + try: + from datetime import datetime + today = datetime.now().strftime("%Y-%m-%d") + + # Get RSI + rsi_data = execute_tool("get_indicators", symbol=ticker, indicator="rsi", curr_date=today, look_back_days=14) + + # Simple parsing of the string report to find the latest value + # The report format is usually "## rsi values...\n\nDATE: VALUE" + # We'll just store the report for the LLM to analyze in deep dive if needed, + # OR we can try to parse it here. For now, let's just add it to metadata. + cand['technical_indicators'] = rsi_data + + except Exception as e: + print(f" Error getting technicals for {ticker}: {e}") + filtered_candidates.append(cand) except Exception as e: @@ -276,8 +424,9 @@ Data: # 1. Get News Sentiment news = execute_tool("get_news", ticker=ticker, start_date=start_date, end_date=end_date) - # 2. Get Insider Transactions + # 2. Get Insider Transactions & Sentiment insider = execute_tool("get_insider_transactions", ticker=ticker) + insider_sentiment = execute_tool("get_insider_sentiment", ticker=ticker) # 3. Get Fundamentals (for the Contrarian check) fundamentals = execute_tool("get_fundamentals", ticker=ticker, curr_date=end_date) @@ -290,6 +439,7 @@ Data: "strategy": strategy, "news": news, "insider_transactions": insider, + "insider_sentiment": insider_sentiment, "fundamentals": fundamentals, "recommendations": recommendations }) @@ -313,6 +463,7 @@ Data: "strategy": opp["strategy"], # Truncate to ~1000 chars each (roughly 250 tokens) "news": opp["news"][:1000] + "..." if len(opp["news"]) > 1000 else opp["news"], + "insider_sentiment": opp.get("insider_sentiment", "")[:500], "insider_transactions": opp["insider_transactions"][:1000] + "..." if len(opp["insider_transactions"]) > 1000 else opp["insider_transactions"], "fundamentals": opp["fundamentals"][:1000] + "..." if len(opp["fundamentals"]) > 1000 else opp["fundamentals"], "recommendations": opp["recommendations"][:1000] + "..." if len(opp["recommendations"]) > 1000 else opp["recommendations"], diff --git a/tradingagents/schemas/__init__.py b/tradingagents/schemas/__init__.py index feff8da5..0782d204 100644 --- a/tradingagents/schemas/__init__.py +++ b/tradingagents/schemas/__init__.py @@ -3,6 +3,7 @@ from .llm_outputs import ( TradeDecision, TickerList, + ThemeList, MarketMover, MarketMovers, InvestmentOpportunity, @@ -14,6 +15,7 @@ from .llm_outputs import ( __all__ = [ "TradeDecision", "TickerList", + "ThemeList", "MarketMovers", "MarketMover", "InvestmentOpportunity", diff --git a/tradingagents/schemas/llm_outputs.py b/tradingagents/schemas/llm_outputs.py index 4a97fd57..73dec960 100644 --- a/tradingagents/schemas/llm_outputs.py +++ b/tradingagents/schemas/llm_outputs.py @@ -26,6 +26,7 @@ class TradeDecision(BaseModel): ) + class TickerList(BaseModel): """Structured output for ticker symbol lists.""" @@ -34,6 +35,14 @@ class TickerList(BaseModel): ) +class ThemeList(BaseModel): + """Structured output for market themes.""" + + themes: List[str] = Field( + description="List of trending market themes or sectors" + ) + + class MarketMover(BaseModel): """Individual market mover entry.""" diff --git a/tradingagents/tools/executor.py b/tradingagents/tools/executor.py index 40533999..8a7ef07f 100644 --- a/tradingagents/tools/executor.py +++ b/tradingagents/tools/executor.py @@ -6,14 +6,16 @@ registry-based approach. All routing decisions are driven by the tool registry. Key improvements over old system: - Single registry lookup instead of multiple dictionary lookups -- Clear fallback logic per tool (optional, not mandatory) +- Supports both fallback and aggregate execution modes +- Parallel vendor execution for aggregate mode - Better error messages and debugging - No dual registry systems """ -from typing import Any, Optional, List +from typing import Any, Optional, List, Dict import logging -from tradingagents.tools.registry import TOOL_REGISTRY, get_vendor_config +import concurrent.futures +from tradingagents.tools.registry import TOOL_REGISTRY, get_vendor_config, get_tool_metadata logger = logging.getLogger(__name__) @@ -28,43 +30,30 @@ class VendorNotFoundError(Exception): pass -def execute_tool(tool_name: str, *args, **kwargs) -> Any: - """Execute a tool using the registry-based routing system. +def _execute_fallback(tool_name: str, vendor_config: Dict, *args, **kwargs) -> Any: + """Execute vendors sequentially with fallback (original behavior). - This is the simplified replacement for route_to_vendor(). + Tries vendors in priority order and returns the first successful result. Args: - tool_name: Name of the tool to execute (e.g., "get_stock_data") - *args: Positional arguments to pass to the tool - **kwargs: Keyword arguments to pass to the tool + tool_name: Name of the tool + vendor_config: Vendor configuration from registry + *args: Positional arguments for vendor function + **kwargs: Keyword arguments for vendor function Returns: - Result from the vendor function + Result from first successful vendor Raises: - VendorNotFoundError: If tool or vendor implementation not found - ToolExecutionError: If all vendors fail to execute the tool + ToolExecutionError: If all vendors fail """ - - # Step 1: Get vendor configuration from registry - vendor_config = get_vendor_config(tool_name) - - if not vendor_config["vendor_priority"]: - raise VendorNotFoundError( - f"Tool '{tool_name}' not found in registry or has no vendors configured" - ) - - # Step 2: Get vendor functions and priority list vendor_functions = vendor_config["vendors"] vendors_to_try = vendor_config["vendor_priority"] - - logger.debug(f"Executing tool '{tool_name}' with vendors: {vendors_to_try}") - - # Step 3: Try each vendor in priority order errors = [] + logger.debug(f"Executing tool '{tool_name}' in fallback mode with vendors: {vendors_to_try}") + for vendor_name in vendors_to_try: - # Get the vendor function directly from registry vendor_func = vendor_functions.get(vendor_name) if not vendor_func: @@ -80,16 +69,131 @@ def execute_tool(tool_name: str, *args, **kwargs) -> Any: error_msg = f"Vendor '{vendor_name}' failed: {str(e)}" logger.warning(f"Tool '{tool_name}': {error_msg}") errors.append(error_msg) - - # Continue to next vendor (fallback) continue - # Step 4: All vendors failed + # All vendors failed error_summary = f"Tool '{tool_name}' failed with all vendors:\n" + "\n".join(f" - {err}" for err in errors) logger.error(error_summary) raise ToolExecutionError(error_summary) +def _execute_aggregate(tool_name: str, vendor_config: Dict, metadata: Dict, *args, **kwargs) -> str: + """Execute multiple vendors in parallel and aggregate results. + + Executes all specified vendors simultaneously using ThreadPoolExecutor, + collects successful results, and combines them with vendor labels. + + Args: + tool_name: Name of the tool + vendor_config: Vendor configuration from registry + metadata: Tool metadata from registry + *args: Positional arguments for vendor functions + **kwargs: Keyword arguments for vendor functions + + Returns: + Aggregated results from all successful vendors, formatted with labels + + Raises: + ToolExecutionError: If all vendors fail + """ + vendor_functions = vendor_config["vendors"] + + # Get list of vendors to aggregate (default to all in priority list) + vendors_to_aggregate = metadata.get("aggregate_vendors") or vendor_config["vendor_priority"] + + logger.debug(f"Executing tool '{tool_name}' in aggregate mode with vendors: {vendors_to_aggregate}") + + results = [] + errors = [] + + # Execute vendors in parallel using ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor(max_workers=len(vendors_to_aggregate)) as executor: + # Submit all vendor calls + future_to_vendor = {} + for vendor_name in vendors_to_aggregate: + vendor_func = vendor_functions.get(vendor_name) + if vendor_func: + future = executor.submit(vendor_func, *args, **kwargs) + future_to_vendor[future] = vendor_name + else: + logger.warning(f"Vendor '{vendor_name}' not found in vendors dict for tool '{tool_name}'") + + # Collect results as they complete + for future in concurrent.futures.as_completed(future_to_vendor): + vendor_name = future_to_vendor[future] + try: + result = future.result() + results.append({ + "vendor": vendor_name, + "data": result + }) + logger.debug(f"Tool '{tool_name}': vendor '{vendor_name}' succeeded") + except Exception as e: + error_msg = f"Vendor '{vendor_name}' failed: {str(e)}" + errors.append(error_msg) + logger.warning(f"Tool '{tool_name}': {error_msg}") + + # Check if we got any results + if not results: + error_summary = f"Tool '{tool_name}' aggregate mode: all vendors failed:\n" + "\n".join(f" - {err}" for err in errors) + logger.error(error_summary) + raise ToolExecutionError(error_summary) + + # Format aggregated results with clear vendor labels + formatted_results = [] + for item in results: + vendor_label = f"=== {item['vendor'].upper()} ===" + formatted_results.append(f"{vendor_label}\n{item['data']}") + + # Log partial success if some vendors failed + if errors: + logger.info(f"Tool '{tool_name}': {len(results)} vendors succeeded, {len(errors)} failed") + + return "\n\n".join(formatted_results) + + +def execute_tool(tool_name: str, *args, **kwargs) -> Any: + """Execute a tool using fallback or aggregate mode based on configuration. + + This is the main entry point for tool execution. It dispatches to either + fallback mode (sequential with early return) or aggregate mode (parallel + with result combination) based on the tool's execution_mode setting. + + Args: + tool_name: Name of the tool to execute (e.g., "get_stock_data") + *args: Positional arguments to pass to the tool + **kwargs: Keyword arguments to pass to the tool + + Returns: + Result from vendor function(s). String for aggregate mode (formatted + with vendor labels), Any for fallback mode (raw vendor result). + + Raises: + VendorNotFoundError: If tool or vendor implementation not found + ToolExecutionError: If all vendors fail to execute the tool + """ + # Get vendor configuration and metadata from registry + vendor_config = get_vendor_config(tool_name) + metadata = get_tool_metadata(tool_name) + + if not vendor_config["vendor_priority"]: + raise VendorNotFoundError( + f"Tool '{tool_name}' not found in registry or has no vendors configured" + ) + + if not metadata: + raise VendorNotFoundError(f"Tool '{tool_name}' metadata not found in registry") + + # Check execution mode (defaults to fallback for backward compatibility) + execution_mode = metadata.get("execution_mode", "fallback") + + # Dispatch to appropriate execution strategy + if execution_mode == "aggregate": + return _execute_aggregate(tool_name, vendor_config, metadata, *args, **kwargs) + else: + return _execute_fallback(tool_name, vendor_config, *args, **kwargs) + + def get_tool_info(tool_name: str) -> Optional[dict]: """Get information about a tool from the registry. diff --git a/tradingagents/tools/generator.py b/tradingagents/tools/generator.py index b26e1ef6..6fc96e88 100644 --- a/tradingagents/tools/generator.py +++ b/tradingagents/tools/generator.py @@ -54,6 +54,10 @@ def generate_langchain_tool(tool_name: str, metadata: Dict[str, Any]) -> Callabl # Use **kwargs to handle all parameters def tool_function(**kwargs): """Dynamically generated tool function.""" + # Ensure defaults are applied for missing parameters + for param_name, param_info in parameters.items(): + if param_name not in kwargs and "default" in param_info: + kwargs[param_name] = param_info["default"] return execute_tool(tool_name, **kwargs) # Set function metadata diff --git a/tradingagents/tools/registry.py b/tradingagents/tools/registry.py index 0f2fcb0c..8fa24382 100644 --- a/tradingagents/tools/registry.py +++ b/tradingagents/tools/registry.py @@ -53,6 +53,8 @@ from tradingagents.dataflows.reddit_api import ( ) from tradingagents.dataflows.finnhub_api import ( get_recommendation_trends as get_finnhub_recommendation_trends, + get_earnings_calendar as get_finnhub_earnings_calendar, + get_ipo_calendar as get_finnhub_ipo_calendar, ) from tradingagents.dataflows.twitter_data import ( get_tweets as get_twitter_tweets, @@ -109,6 +111,8 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = { "alpha_vantage": get_alpha_vantage_indicator, }, "vendor_priority": ["yfinance", "alpha_vantage"], + "execution_mode": "aggregate", + "aggregate_vendors": ["yfinance", "alpha_vantage"], "parameters": { "symbol": {"type": "str", "description": "Ticker symbol"}, "indicator": {"type": "str", "description": "Technical indicator (rsi, macd, sma, ema, etc.)"}, @@ -208,6 +212,8 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = { "google": get_google_news, }, "vendor_priority": ["alpha_vantage", "reddit", "openai", "google"], + "execution_mode": "aggregate", + "aggregate_vendors": ["alpha_vantage", "reddit", "google"], "parameters": { "query": {"type": "str", "description": "Search query or ticker symbol"}, "start_date": {"type": "str", "description": "Start date, yyyy-mm-dd"}, @@ -227,6 +233,7 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = { "alpha_vantage": get_alpha_vantage_global_news, }, "vendor_priority": ["openai", "google", "reddit", "alpha_vantage"], + "execution_mode": "aggregate", "parameters": { "date": {"type": "str", "description": "Date for news, yyyy-mm-dd"}, "look_back_days": {"type": "int", "description": "Days to look back", "default": 7}, @@ -310,6 +317,36 @@ TOOL_REGISTRY: Dict[str, Dict[str, Any]] = { "returns": "str: Tweets matching the query", }, + "get_earnings_calendar": { + "description": "Get upcoming earnings announcements (catalysts for volatility)", + "category": "discovery", + "agents": [], + "vendors": { + "finnhub": get_finnhub_earnings_calendar, + }, + "vendor_priority": ["finnhub"], + "parameters": { + "from_date": {"type": "str", "description": "Start date in yyyy-mm-dd format"}, + "to_date": {"type": "str", "description": "End date in yyyy-mm-dd format"}, + }, + "returns": "str: Formatted earnings calendar with EPS and revenue estimates", + }, + + "get_ipo_calendar": { + "description": "Get upcoming and recent IPOs (new listing opportunities)", + "category": "discovery", + "agents": [], + "vendors": { + "finnhub": get_finnhub_ipo_calendar, + }, + "vendor_priority": ["finnhub"], + "parameters": { + "from_date": {"type": "str", "description": "Start date in yyyy-mm-dd format"}, + "to_date": {"type": "str", "description": "End date in yyyy-mm-dd format"}, + }, + "returns": "str: Formatted IPO calendar with pricing and share details", + }, + "get_reddit_discussions": { "description": "Get Reddit discussions about a specific ticker", "category": "news_data", @@ -465,6 +502,26 @@ def validate_registry() -> List[str]: if not isinstance(metadata.get("parameters"), dict): issues.append(f"{tool_name}: Parameters must be a dictionary") + # Validate execution_mode if present + if "execution_mode" in metadata: + execution_mode = metadata["execution_mode"] + if execution_mode not in ["fallback", "aggregate"]: + issues.append(f"{tool_name}: Invalid execution_mode '{execution_mode}', must be 'fallback' or 'aggregate'") + + # Validate aggregate_vendors if present + if "aggregate_vendors" in metadata: + aggregate_vendors = metadata["aggregate_vendors"] + if not isinstance(aggregate_vendors, list): + issues.append(f"{tool_name}: aggregate_vendors must be a list") + else: + for vendor_name in aggregate_vendors: + if vendor_name not in vendors: + issues.append(f"{tool_name}: aggregate_vendor '{vendor_name}' not in vendors dict") + + # Warn if aggregate_vendors specified but execution_mode is not aggregate + if metadata.get("execution_mode") != "aggregate": + issues.append(f"{tool_name}: aggregate_vendors specified but execution_mode is not 'aggregate'") + return issues