TradingAgents/tradingagents/integrations/sequential_mcp/client.py

208 lines
9.2 KiB
Python

"""Client for delegating planning to the Sequential Thinking MCP server."""
from __future__ import annotations

import asyncio
import logging
import shlex
import sys
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional

try:  # pragma: no cover - optional dependency during linting
    from mcp.client.session import ClientSession
    from mcp.client.streamable_http import streamablehttp_client
    from mcp.client.stdio import StdioServerParameters, stdio_client
except ImportError:  # pragma: no cover - surfaced at runtime with helpful error
    ClientSession = None  # type: ignore[assignment]
    streamablehttp_client = None  # type: ignore[assignment]
    stdio_client = None  # type: ignore[assignment]
    StdioServerParameters = None  # type: ignore[assignment]

from .config import SequentialMCPConfig
from ..mcp_handshake import emit_console, perform_handshake


class SequentialMCPError(RuntimeError):
    """Error signalling that a Sequential Thinking MCP request could not be fulfilled.

    Raised by the client when the integration is disabled, the optional ``mcp``
    dependency is missing, the configured transport is unusable, the server lacks
    the required tools, or calling a tool fails.
    """


class SequentialMCPClient:
"""Simple interface for requesting action plans from the Sequential Thinking MCP server."""
def __init__(self, config: SequentialMCPConfig, logger: Optional[logging.Logger] = None) -> None:
self.config = config
self.logger = logger or logging.getLogger(__name__)
def generate_plan(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Request a plan from the sequential thinking server."""
return self._call_tool("sequential_thinking", payload)
def verify_connection(self) -> bool:
"""Check that the Sequential MCP server is reachable and exposes its tool."""
if not self.config.enabled:
msg = "Sequential MCP disabled; skipping connectivity check."
self.logger.info(msg)
emit_console("INFO", msg)
return False
try:
return asyncio.run(self._verify_async())
except SequentialMCPError as exc:
msg = f"Sequential MCP connectivity probe failed: {exc}"
self.logger.warning(msg)
emit_console("WARNING", msg)
return False
except Exception as exc: # pragma: no cover - diagnostic logging only
msg = f"Sequential MCP connectivity probe failed: {exc}"
self.logger.warning(msg)
emit_console("WARNING", msg)
return False
async def _verify_async(self) -> bool:
async with self._acquire_session() as session:
tools_response = await session.list_tools()
available = [getattr(tool, "name", "") for tool in getattr(tools_response, "tools", [])]
missing = self.config.required_toolset(available)
if missing:
msg = "Sequential MCP connected but missing required tools: " + ", ".join(missing)
self.logger.warning(msg)
emit_console("WARNING", msg)
return False
tools_list = ", ".join(sorted(filter(None, available)))
msg = f"Sequential MCP connectivity verified (tools={tools_list})"
self.logger.info(msg)
emit_console("INFO", msg)
return True
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _call_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
if not self.config.enabled:
raise SequentialMCPError("Sequential MCP integration is disabled.")
if ClientSession is None or (
self.config.transport == "http" and streamablehttp_client is None
):
raise SequentialMCPError(
"Package 'mcp' is required to use the Sequential Thinking MCP integration. Install it with `pip install mcp`."
)
self.config.validate()
payload = arguments or {}
return asyncio.run(self._call_tool_async(tool_name, payload))
async def _call_tool_async(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
try:
async with self._acquire_session() as session:
tools_response = await session.list_tools()
available = [getattr(tool, "name", "") for tool in getattr(tools_response, "tools", [])]
missing = self.config.required_toolset(available)
if missing:
raise SequentialMCPError(
"Sequential MCP server is missing required tools: " + ", ".join(missing)
)
if tool_name not in available:
raise SequentialMCPError(f"Sequential MCP server does not expose tool '{tool_name}'.")
result = await session.call_tool(tool_name, arguments)
return self._extract_content(result)
except SequentialMCPError:
raise
except BaseExceptionGroup as exc_group: # pragma: no cover
message = self._flatten_exception_message(exc_group)
raise SequentialMCPError(f"Failed to call Sequential MCP tool '{tool_name}': {message}") from exc_group
except Exception as exc:
raise SequentialMCPError(f"Failed to call Sequential MCP tool '{tool_name}': {exc}") from exc
@asynccontextmanager
async def _acquire_session(self) -> "ClientSession":
if ClientSession is None:
raise SequentialMCPError(
"Package 'mcp' is required to use the Sequential Thinking MCP integration. Install it with `pip install mcp`."
)
if self.config.transport == "http":
if streamablehttp_client is None:
raise SequentialMCPError(
"HTTP transport requires the 'mcp' package. Install it with `pip install mcp`."
)
base_url = self._build_http_base()
self.logger.debug("Connecting to Sequential MCP via HTTP at %s", base_url)
async with streamablehttp_client(
url=base_url,
timeout=self.config.timeout_seconds,
) as (read_stream, write_stream, _session_id_cb):
async with ClientSession(read_stream, write_stream) as session:
await perform_handshake(
session,
client_label="Sequential",
logger=self.logger,
)
yield session
return
if self.config.transport == "stdio":
if stdio_client is None or StdioServerParameters is None:
raise SequentialMCPError(
"STDIO transport requires the 'mcp' package. Install it with `pip install mcp`."
)
command = self.config.command or "python -m tradingagents.integrations.sequential_mcp.server"
args = shlex.split(command)
if not args:
raise SequentialMCPError("STDIO command is empty.")
params = StdioServerParameters(command=args[0], args=args[1:])
self.logger.debug("Launching Sequential MCP via STDIO: %s", args)
async with stdio_client(params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await perform_handshake(
session,
client_label="Sequential",
logger=self.logger,
)
yield session
return
raise SequentialMCPError(f"Unsupported transport '{self.config.transport}'.")
def _build_http_base(self) -> str:
if self.config.base_url:
return self.config.base_url.rstrip("/")
host = self.config.host
if host.startswith("http://") or host.startswith("https://"):
return host.rstrip("/")
return f"http://{host}:{self.config.port}/mcp"
@staticmethod
def _extract_content(result: Any) -> Dict[str, Any]:
content = getattr(result, "content", None)
if content is None and isinstance(result, dict):
content = result.get("content")
fragments: List[str] = []
if content:
for item in content:
text_value = getattr(item, "text", None)
if text_value is None and isinstance(item, dict):
text_value = item.get("text")
fragments.append(str(text_value) if text_value is not None else str(item))
text = "\n".join(fragment for fragment in fragments if fragment)
structured = getattr(result, "structured_content", None)
if structured is None and isinstance(result, dict):
structured = result.get("structured_content") or result.get("structuredContent")
return {"text": text, "structured": structured, "raw": result}
@staticmethod
def _flatten_exception_message(exc: BaseException) -> str:
if isinstance(exc, BaseExceptionGroup):
parts: List[str] = []
for item in exc.exceptions:
message = SequentialMCPClient._flatten_exception_message(item)
if message:
parts.append(message)
return "; ".join(parts)
return str(exc)