TradingAgents/tradingagents/tools/executor.py

307 lines
10 KiB
Python

"""
Tool Executor - Simplified Tool Execution with Registry-Based Routing
This module replaces the complex route_to_vendor() function with a simpler,
registry-based approach. All routing decisions are driven by the tool registry.
Key improvements over old system:
- Single registry lookup instead of multiple dictionary lookups
- Supports both fallback and aggregate execution modes
- Parallel vendor execution for aggregate mode
- Better error messages and debugging
- No dual registry systems
"""
import concurrent.futures
from typing import Any, Dict, List, Optional
from tradingagents.tools.registry import TOOL_REGISTRY, get_tool_metadata, get_vendor_config
from tradingagents.utils.logger import get_logger
# Module-level logger shared by the execution helpers and the __main__ self-check below.
logger = get_logger(__name__)
class ToolExecutionError(Exception):
    """Signal that a tool produced no result from any of its vendors.

    The message carries a per-vendor summary of what went wrong.
    """
class VendorNotFoundError(Exception):
    """Signal that a tool has no usable vendor implementation or registry entry."""
def _execute_fallback(tool_name: str, vendor_config: Dict, *args, **kwargs) -> Any:
    """Execute vendors sequentially with fallback (original behavior).

    Tries vendors in priority order and returns the first successful result.
    Vendors named in the priority list but absent from the vendor-function
    mapping are skipped, and are reported in the final error summary.

    Args:
        tool_name: Name of the tool
        vendor_config: Vendor configuration from registry; reads the
            ``"vendors"`` (name -> callable) and ``"vendor_priority"`` keys.
        *args: Positional arguments for vendor function
        **kwargs: Keyword arguments for vendor function

    Returns:
        Result from first successful vendor

    Raises:
        ToolExecutionError: If no vendor produced a result. When at least one
            vendor raised, the error is chained to the last such exception.
    """
    vendor_functions = vendor_config["vendors"]
    vendors_to_try = vendor_config["vendor_priority"]
    errors = []
    last_exc: Optional[Exception] = None

    logger.debug(f"Executing tool '{tool_name}' in fallback mode with vendors: {vendors_to_try}")

    for vendor_name in vendors_to_try:
        vendor_func = vendor_functions.get(vendor_name)
        if not vendor_func:
            logger.warning(f"Vendor '{vendor_name}' not found in registry for tool '{tool_name}'")
            # Record it so an all-missing configuration still yields a useful summary.
            errors.append(f"Vendor '{vendor_name}' not found in registry")
            continue

        try:
            result = vendor_func(*args, **kwargs)
        except Exception as e:
            # Include the exception type: many exceptions (e.g. TimeoutError())
            # have an empty str(), which would leave the message blank.
            error_msg = f"Vendor '{vendor_name}' failed: {type(e).__name__}: {e}"
            logger.warning(f"Tool '{tool_name}': {error_msg}")
            errors.append(error_msg)
            last_exc = e
            continue

        logger.debug(f"Tool '{tool_name}' succeeded with vendor '{vendor_name}'")
        return result

    # All vendors failed (or none were available)
    error_summary = f"Tool '{tool_name}' failed with all vendors:\n" + "\n".join(
        f" - {err}" for err in errors
    )
    logger.error(error_summary)
    # Chain to the last vendor exception so its traceback is preserved for debugging.
    raise ToolExecutionError(error_summary) from last_exc
def _execute_aggregate(tool_name: str, vendor_config: Dict, metadata: Dict, *args, **kwargs) -> str:
    """Execute multiple vendors in parallel and aggregate results.

    Runs every available vendor concurrently on a ThreadPoolExecutor, collects
    the successful results, and combines them with vendor labels. Sections
    appear in the configured vendor order rather than thread completion order,
    so identical inputs always produce identically laid-out output.

    Args:
        tool_name: Name of the tool
        vendor_config: Vendor configuration from registry; reads the
            ``"vendors"`` (name -> callable) and ``"vendor_priority"`` keys.
        metadata: Tool metadata from registry; an optional
            ``"aggregate_vendors"`` list overrides the priority list.
        *args: Positional arguments for vendor functions
        **kwargs: Keyword arguments for vendor functions

    Returns:
        Aggregated results from all successful vendors. Each section is headed by
        ``=== VENDOR ===``, and sections are separated by blank lines.

    Raises:
        ToolExecutionError: If no vendor produced a result, either because all
            of them failed or because none were found in the vendor mapping.
            When at least one vendor raised, the error is chained to the last
            such exception.
    """
    vendor_functions = vendor_config["vendors"]
    # Get list of vendors to aggregate (default to all in priority list)
    vendors_to_aggregate = metadata.get("aggregate_vendors") or vendor_config["vendor_priority"]
    logger.debug(
        f"Executing tool '{tool_name}' in aggregate mode with vendors: {vendors_to_aggregate}"
    )

    errors = []
    last_exc: Optional[Exception] = None

    # Resolve callables first, so the pool is sized only for vendors that can
    # actually run and is never created with max_workers=0.
    runnable = []
    for vendor_name in vendors_to_aggregate:
        vendor_func = vendor_functions.get(vendor_name)
        if vendor_func:
            runnable.append((vendor_name, vendor_func))
        else:
            logger.warning(
                f"Vendor '{vendor_name}' not found in vendors dict for tool '{tool_name}'"
            )
            errors.append(f"Vendor '{vendor_name}' not found in vendors dict")

    submitted = []
    if runnable:
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(runnable)) as executor:
            submitted = [
                (vendor_name, executor.submit(vendor_func, *args, **kwargs))
                for vendor_name, vendor_func in runnable
            ]
        # Leaving the `with` block waits for every future (shutdown(wait=True)).
        # Results can therefore be read back in submission (configured) order,
        # which keeps the output deterministic.

    results = []
    for vendor_name, future in submitted:
        try:
            data = future.result()
        except Exception as e:
            # Include the exception type: str(e) is empty for many exceptions.
            error_msg = f"Vendor '{vendor_name}' failed: {type(e).__name__}: {e}"
            errors.append(error_msg)
            logger.warning(f"Tool '{tool_name}': {error_msg}")
            last_exc = e
            continue
        results.append((vendor_name, data))
        logger.debug(f"Tool '{tool_name}': vendor '{vendor_name}' succeeded")

    # Check if we got any results
    if not results:
        error_summary = f"Tool '{tool_name}' aggregate mode: all vendors failed:\n" + "\n".join(
            f" - {err}" for err in errors
        )
        logger.error(error_summary)
        raise ToolExecutionError(error_summary) from last_exc

    # Log partial success if some vendors failed or were missing
    if errors:
        logger.info(f"Tool '{tool_name}': {len(results)} vendors succeeded, {len(errors)} failed")

    # Format aggregated results with clear vendor labels
    return "\n\n".join(f"=== {name.upper()} ===\n{data}" for name, data in results)
def execute_tool(tool_name: str, *args, **kwargs) -> Any:
    """Execute a tool using fallback or aggregate mode based on configuration.

    This is the main entry point for tool execution. It dispatches either to
    fallback mode (sequential, returning early on the first success) or to
    aggregate mode (parallel, combining results), based on the tool's
    ``execution_mode`` metadata. A missing or unrecognized mode falls back to
    sequential execution; an unrecognized mode also logs a warning.

    Args:
        tool_name: Name of the tool to execute (e.g., "get_stock_data")
        *args: Positional arguments to pass to the tool
        **kwargs: Keyword arguments to pass to the tool

    Returns:
        Result from vendor function(s). String for aggregate mode (formatted
        with vendor labels), Any for fallback mode (raw vendor result).

    Raises:
        VendorNotFoundError: If tool or vendor implementation not found
        ToolExecutionError: If all vendors fail to execute the tool
    """
    # Get vendor configuration and metadata from registry
    vendor_config = get_vendor_config(tool_name)
    metadata = get_tool_metadata(tool_name)

    # Read with .get(), as list_available_vendors does, so an unknown tool
    # surfaces as the documented VendorNotFoundError rather than a bare
    # KeyError/TypeError.
    if not vendor_config or not vendor_config.get("vendor_priority"):
        raise VendorNotFoundError(
            f"Tool '{tool_name}' not found in registry or has no vendors configured"
        )
    if not metadata:
        raise VendorNotFoundError(f"Tool '{tool_name}' metadata not found in registry")

    # Check execution mode (defaults to fallback for backward compatibility)
    execution_mode = metadata.get("execution_mode", "fallback")

    # Dispatch to appropriate execution strategy
    if execution_mode == "aggregate":
        return _execute_aggregate(tool_name, vendor_config, metadata, *args, **kwargs)

    if execution_mode != "fallback":
        # Still fall back, as before, but make a likely config typo visible.
        logger.warning(
            f"Tool '{tool_name}' has unknown execution_mode '{execution_mode}'; using fallback"
        )
    return _execute_fallback(tool_name, vendor_config, *args, **kwargs)
def get_tool_info(tool_name: str) -> Optional[dict]:
    """Look up a tool's raw registry entry by name.

    Intended for debugging and introspection.

    Args:
        tool_name: Registry key of the tool.

    Returns:
        The registry entry stored under ``tool_name``, or ``None`` when the
        registry has no such key.
    """
    entry = TOOL_REGISTRY.get(tool_name)
    return entry
def list_available_vendors(tool_name: str) -> List[str]:
    """List all available vendors for a tool.

    Args:
        tool_name: Name of the tool

    Returns:
        A new list of vendor names in priority order. The list is empty when
        the tool has no vendor configuration.
    """
    vendor_config = get_vendor_config(tool_name)
    if not vendor_config:
        return []
    # Return a copy so callers cannot reorder or mutate the registry's priority
    # list in place. `or []` also covers a present-but-None value.
    return list(vendor_config.get("vendor_priority") or [])
# ============================================================================
# LEGACY COMPATIBILITY LAYER
# ============================================================================
def route_to_vendor(method: str, *args, **kwargs) -> Any:
    """Deprecated shim kept so legacy ``route_to_vendor()`` call sites still work.

    Logs a deprecation warning on every call, then forwards everything to
    ``execute_tool()``. New code should call ``execute_tool()`` directly.

    Args:
        method: Tool name (kept under its legacy parameter name)
        *args: Positional arguments forwarded to the tool
        **kwargs: Keyword arguments forwarded to the tool

    Returns:
        Whatever ``execute_tool()`` returns for this tool.
    """
    deprecation_note = f"route_to_vendor() is deprecated. Use execute_tool('{method}', ...) instead."
    logger.warning(deprecation_note)
    outcome = execute_tool(method, *args, **kwargs)
    return outcome
# ============================================================================
# TESTING & DEBUGGING
# ============================================================================
if __name__ == "__main__":
    # Manual smoke test when this module is run directly: prints vendor routing
    # for every registered tool, sample registry entries, and validation issues.
    import logging

    from tradingagents.tools.registry import get_all_tools, validate_registry

    # Enable debug logging
    logging.basicConfig(level=logging.DEBUG)

    max_issues_shown = 10

    logger.info("=" * 70)
    logger.info("TOOL EXECUTOR - TESTING")
    logger.info("=" * 70)

    # Test 1: List available vendors for each tool
    logger.info("Available vendors per tool:")
    for tool_name in get_all_tools():
        vendors = list_available_vendors(tool_name)
        logger.info(f" {tool_name}:")
        logger.info(f" Primary: {vendors[0] if vendors else 'None'}")
        if len(vendors) > 1:
            logger.info(f" Fallbacks: {', '.join(vendors[1:])}")

    # Test 2: Show tool info. Use .get() so an entry missing an optional
    # field doesn't abort the whole self-check with a KeyError.
    logger.info("Tool info examples:")
    for tool_name in ["get_stock_data", "get_news", "get_fundamentals"]:
        info = get_tool_info(tool_name)
        if not info:
            continue
        agents = info.get("agents")
        logger.info(f" {tool_name}:")
        logger.info(f" Category: {info.get('category', 'None')}")
        logger.info(f" Agents: {', '.join(agents) if agents else 'None'}")
        logger.info(f" Description: {info.get('description', 'None')}")

    # Test 3: Validate registry
    logger.info("Validating registry:")
    issues = validate_registry()
    if issues:
        logger.warning("⚠️ Registry validation issues found:")
        for issue in issues[:max_issues_shown]:
            logger.warning(f" - {issue}")
        if len(issues) > max_issues_shown:
            logger.warning(f" ... and {len(issues) - max_issues_shown} more")
    else:
        logger.info("✅ Registry is valid!")

    logger.info("=" * 70)