"""
Tool Executor - Simplified Tool Execution with Registry-Based Routing

This module replaces the complex route_to_vendor() function with a simpler,
registry-based approach. All routing decisions are driven by the tool registry.

Key improvements over the old system:
- Single registry lookup instead of multiple dictionary lookups
- Supports both fallback and aggregate execution modes
- Parallel vendor execution for aggregate mode
- Better error messages and debugging
- No dual registry systems
"""

import concurrent.futures
from typing import Any, Dict, List, Optional

from tradingagents.tools.registry import TOOL_REGISTRY, get_tool_metadata, get_vendor_config
from tradingagents.utils.logger import get_logger

logger = get_logger(__name__)


class ToolExecutionError(Exception):
    """Raised when tool execution fails across all vendors."""


class VendorNotFoundError(Exception):
    """Raised when no vendor implementation is found for a tool."""


def _format_failure_summary(header: str, errors: List[str]) -> str:
    """Return ``header`` followed by one bulleted line per collected error message."""
    return header + "\n".join(f" - {err}" for err in errors)


def _execute_fallback(tool_name: str, vendor_config: Dict, *args, **kwargs) -> Any:
    """Execute vendors sequentially with fallback (original behavior).

    Tries vendors in priority order and returns the first successful result.

    Args:
        tool_name: Name of the tool (used for logging and error messages).
        vendor_config: Vendor configuration from the registry. This function reads
            its ``"vendors"`` mapping (vendor name -> callable) and its
            ``"vendor_priority"`` sequence (vendor names, highest priority first).
        *args: Positional arguments for the vendor function.
        **kwargs: Keyword arguments for the vendor function.

    Returns:
        Result from the first vendor whose call does not raise.

    Raises:
        ToolExecutionError: If every vendor is missing from ``"vendors"`` or raises.
            The last vendor exception, if any, is attached as ``__cause__``.
    """
    vendor_functions = vendor_config["vendors"]
    vendors_to_try = vendor_config["vendor_priority"]
    errors: List[str] = []
    last_exc: Optional[BaseException] = None

    logger.debug(
        "Executing tool '%s' in fallback mode with vendors: %s", tool_name, vendors_to_try
    )

    for vendor_name in vendors_to_try:
        vendor_func = vendor_functions.get(vendor_name)
        if not vendor_func:
            logger.warning(
                "Vendor '%s' not found in registry for tool '%s'", vendor_name, tool_name
            )
            # Record the miss so the final summary explains why this vendor was skipped.
            errors.append(f"Vendor '{vendor_name}' not found in registry")
            continue

        try:
            result = vendor_func(*args, **kwargs)
        except Exception as e:  # Any vendor failure means "try the next vendor".
            error_msg = f"Vendor '{vendor_name}' failed: {e}"
            logger.warning("Tool '%s': %s", tool_name, error_msg)
            errors.append(error_msg)
            last_exc = e
            continue

        logger.debug("Tool '%s' succeeded with vendor '%s'", tool_name, vendor_name)
        return result

    # Every vendor was missing or failed.
    error_summary = _format_failure_summary(
        f"Tool '{tool_name}' failed with all vendors:\n", errors
    )
    logger.error(error_summary)
    raise ToolExecutionError(error_summary) from last_exc


def _execute_aggregate(
    tool_name: str, vendor_config: Dict, metadata: Dict, *args, **kwargs
) -> str:
    """Execute multiple vendors in parallel and aggregate results.

    Runs every resolvable vendor concurrently in a ThreadPoolExecutor, collects
    the successful results, and joins them into one string with a
    ``=== VENDOR ===`` header per vendor. Sections follow the configured vendor
    order (not completion order), so the output layout is deterministic.

    Args:
        tool_name: Name of the tool (used for logging and error messages).
        vendor_config: Vendor configuration from the registry. Its ``"vendors"``
            mapping and ``"vendor_priority"`` sequence are read.
        metadata: Tool metadata from the registry. Its optional ``"aggregate_vendors"``
            entry selects which vendors to run; when missing or empty,
            ``"vendor_priority"`` is used instead.
        *args: Positional arguments for the vendor functions.
        **kwargs: Keyword arguments for the vendor functions.

    Returns:
        Aggregated results from all successful vendors, formatted with labels and
        separated by blank lines.

    Raises:
        ToolExecutionError: If no vendor produced a result (all missing or all raised).
            The last vendor exception, if any, is attached as ``__cause__``.
    """
    vendor_functions = vendor_config["vendors"]
    # dict.fromkeys drops duplicate names while keeping the configured order, so a
    # vendor listed twice is only called once and only appears once in the output.
    vendors_to_aggregate = list(
        dict.fromkeys(metadata.get("aggregate_vendors") or vendor_config["vendor_priority"])
    )

    logger.debug(
        "Executing tool '%s' in aggregate mode with vendors: %s", tool_name, vendors_to_aggregate
    )

    errors: List[str] = []
    runnable = []  # (vendor_name, vendor_func) pairs, in configured order
    for vendor_name in vendors_to_aggregate:
        vendor_func = vendor_functions.get(vendor_name)
        if vendor_func:
            runnable.append((vendor_name, vendor_func))
        else:
            logger.warning(
                "Vendor '%s' not found in vendors dict for tool '%s'", vendor_name, tool_name
            )
            errors.append(f"Vendor '{vendor_name}' not found in vendors dict")

    results: Dict[str, Any] = {}
    last_exc: Optional[BaseException] = None
    # ThreadPoolExecutor rejects max_workers=0, so only create a pool when there is work.
    if runnable:
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(runnable)) as executor:
            future_to_vendor = {
                executor.submit(vendor_func, *args, **kwargs): vendor_name
                for vendor_name, vendor_func in runnable
            }
            for future in concurrent.futures.as_completed(future_to_vendor):
                vendor_name = future_to_vendor[future]
                try:
                    results[vendor_name] = future.result()
                except Exception as e:  # One vendor failing must not sink the others.
                    error_msg = f"Vendor '{vendor_name}' failed: {e}"
                    errors.append(error_msg)
                    logger.warning("Tool '%s': %s", tool_name, error_msg)
                    last_exc = e
                else:
                    logger.debug("Tool '%s': vendor '%s' succeeded", tool_name, vendor_name)

    if not results:
        error_summary = _format_failure_summary(
            f"Tool '{tool_name}' aggregate mode: all vendors failed:\n", errors
        )
        logger.error(error_summary)
        raise ToolExecutionError(error_summary) from last_exc

    # Walk `runnable` (configured order) rather than `results` (completion order).
    formatted_results = [
        f"=== {vendor_name.upper()} ===\n{results[vendor_name]}"
        for vendor_name, _ in runnable
        if vendor_name in results
    ]

    if errors:
        logger.info(
            "Tool '%s': %d vendors succeeded, %d failed", tool_name, len(results), len(errors)
        )

    return "\n\n".join(formatted_results)


def execute_tool(tool_name: str, *args, **kwargs) -> Any:
    """Execute a tool using fallback or aggregate mode based on configuration.

    This is the main entry point for tool execution. It dispatches to either
    fallback mode (sequential with early return) or aggregate mode (parallel with
    result combination) based on the tool's ``execution_mode`` metadata.
    A missing ``execution_mode`` means fallback. An unrecognised value also uses
    fallback, after logging a warning.

    Args:
        tool_name: Name of the tool to execute (e.g., "get_stock_data").
        *args: Positional arguments to pass to the tool.
        **kwargs: Keyword arguments to pass to the tool.

    Returns:
        Result from the vendor function(s). A string in aggregate mode
        (formatted with vendor labels); the raw vendor result in fallback mode.

    Raises:
        VendorNotFoundError: If the tool has no vendor priority list or no metadata.
        ToolExecutionError: If all vendors fail to execute the tool.
    """
    # Get vendor configuration and metadata from registry
    vendor_config = get_vendor_config(tool_name)
    metadata = get_tool_metadata(tool_name)

    # .get() (as in list_available_vendors) so an unknown tool raises
    # VendorNotFoundError rather than a bare KeyError.
    if not vendor_config.get("vendor_priority"):
        raise VendorNotFoundError(
            f"Tool '{tool_name}' not found in registry or has no vendors configured"
        )

    if not metadata:
        raise VendorNotFoundError(f"Tool '{tool_name}' metadata not found in registry")

    # Defaults to fallback for backward compatibility.
    execution_mode = metadata.get("execution_mode", "fallback")

    if execution_mode == "aggregate":
        return _execute_aggregate(tool_name, vendor_config, metadata, *args, **kwargs)

    if execution_mode != "fallback":
        # Preserve the historical behavior (fall back) but make the misconfiguration visible.
        logger.warning(
            "Tool '%s': unknown execution_mode %r, using fallback mode", tool_name, execution_mode
        )
    return _execute_fallback(tool_name, vendor_config, *args, **kwargs)


def get_tool_info(tool_name: str) -> Optional[dict]:
    """Get information about a tool from the registry.

    Useful for debugging and introspection. The returned object is the entry
    stored in TOOL_REGISTRY itself, not a copy.

    Args:
        tool_name: Name of the tool.

    Returns:
        Tool metadata dict, or None if not found.
    """
    return TOOL_REGISTRY.get(tool_name)


def list_available_vendors(tool_name: str) -> List[str]:
    """List all available vendors for a tool.

    Args:
        tool_name: Name of the tool.

    Returns:
        A new list of vendor names in priority order; empty if none are configured.
        A copy is returned in case get_vendor_config hands back a reference into the
        registry, so callers cannot reorder or mutate the configured priority.
    """
    vendor_config = get_vendor_config(tool_name)
    return list(vendor_config.get("vendor_priority") or [])


# ============================================================================
# LEGACY COMPATIBILITY LAYER
# ============================================================================


def route_to_vendor(method: str, *args, **kwargs) -> Any:
    """Legacy compatibility function.

    This provides backward compatibility with the old route_to_vendor() calls.
    Internally, it just delegates to execute_tool(), logging a deprecation
    warning on every call.

    DEPRECATED: Use execute_tool() directly in new code.

    Args:
        method: Tool name (legacy parameter name).
        *args: Positional arguments.
        **kwargs: Keyword arguments.

    Returns:
        Result from tool execution.
    """
    logger.warning(
        "route_to_vendor() is deprecated. Use execute_tool('%s', ...) instead.", method
    )
    return execute_tool(method, *args, **kwargs)


# ============================================================================
# TESTING & DEBUGGING
# ============================================================================


def _run_self_check() -> None:
    """Log per-tool vendor routing, sample tool info, and registry validation issues."""
    import logging

    from tradingagents.tools.registry import get_all_tools, validate_registry

    # Enable debug logging
    logging.basicConfig(level=logging.DEBUG)

    logger.info("=" * 70)
    logger.info("TOOL EXECUTOR - TESTING")
    logger.info("=" * 70)

    # Test 1: List available vendors for each tool
    logger.info("Available vendors per tool:")
    for tool_name in get_all_tools():
        vendors = list_available_vendors(tool_name)
        logger.info("  %s:", tool_name)
        logger.info("    Primary: %s", vendors[0] if vendors else "None")
        if len(vendors) > 1:
            logger.info("    Fallbacks: %s", ", ".join(vendors[1:]))

    # Test 2: Show tool info (use .get so a sparse registry entry doesn't crash the check)
    logger.info("Tool info examples:")
    for tool_name in ["get_stock_data", "get_news", "get_fundamentals"]:
        info = get_tool_info(tool_name)
        if info:
            agents = info.get("agents")
            logger.info("  %s:", tool_name)
            logger.info("    Category: %s", info.get("category"))
            logger.info("    Agents: %s", ", ".join(agents) if agents else "None")
            logger.info("    Description: %s", info.get("description"))

    # Test 3: Validate registry
    logger.info("Validating registry:")
    issues = validate_registry()
    if issues:
        logger.warning("⚠️ Registry validation issues found:")
        for issue in issues[:10]:  # Show first 10
            logger.warning("  - %s", issue)
        if len(issues) > 10:
            logger.warning("  ... and %d more", len(issues) - 10)
    else:
        logger.info("✅ Registry is valid!")

    logger.info("=" * 70)


if __name__ == "__main__":
    _run_self_check()