diff --git a/scripts/smoke_claude_agent_analyst.py b/scripts/smoke_claude_agent_analyst.py new file mode 100644 index 00000000..6e928292 --- /dev/null +++ b/scripts/smoke_claude_agent_analyst.py @@ -0,0 +1,51 @@ +"""Shape B smoke test: run the market analyst end-to-end against a ticker/date +via the claude_agent provider, confirming MCP tool translation and the SDK- +native tool loop. + +Requires the user to be logged into Claude Code. +""" + +import time + +from langchain_core.messages import HumanMessage + +from tradingagents.agents.analysts.market_analyst import create_market_analyst +from tradingagents.llm_clients.factory import create_llm_client + + +def main(): + client = create_llm_client(provider="claude_agent", model="sonnet") + llm = client.get_llm() + + node = create_market_analyst(llm) + + state = { + "trade_date": "2025-10-15", + "company_of_interest": "AAPL", + "messages": [ + HumanMessage( + content=( + "Produce a concise market analysis report for AAPL based on " + "the most recent price data and a few key technical indicators. " + "Keep it under 500 words." + ) + ) + ], + } + + print(f"Running market analyst on {state['company_of_interest']} @ {state['trade_date']}...") + start = time.monotonic() + output = node(state) + elapsed = time.monotonic() - start + + report = output.get("market_report", "") + print(f"\n--- market_report ({elapsed:.1f}s, {len(report)} chars) ---\n") + print(report) + + assert report, "market_report is empty" + assert "messages" in output and len(output["messages"]) == 1 + print("\nShape B smoke test OK.") + + +if __name__ == "__main__": + main() diff --git a/tradingagents/agents/analysts/_claude_agent_runner.py b/tradingagents/agents/analysts/_claude_agent_runner.py new file mode 100644 index 00000000..468ee011 --- /dev/null +++ b/tradingagents/agents/analysts/_claude_agent_runner.py @@ -0,0 +1,100 @@ +"""SDK-native analyst runner. + +When the configured LLM is :class:`ChatClaudeAgent`, the analyst node delegates +the whole tool-calling loop to ``claude-agent-sdk``. The SDK owns the loop: +Claude iteratively invokes the translated MCP tools and returns a final text +report. No LangGraph ToolNode involvement — the analyst returns a terminal +AIMessage with zero tool_calls, so the existing conditional edges route +straight to the message-clear node. +""" + +import asyncio +from typing import Any, Dict, List + +from langchain_core.messages import AIMessage, HumanMessage + +from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent +from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server + + +def _build_user_prompt(state: Dict[str, Any]) -> str: + """Extract any human content from the incoming message sequence. + + Existing analysts rely on LangGraph feeding tool-call round trips through + state["messages"]. On the SDK path we collapse the incoming messages into a + single user prompt — tool results are consumed by the SDK loop, not via + LangGraph, so only the human-authored content matters here. + """ + parts: List[str] = [] + for msg in state.get("messages", []): + if isinstance(msg, HumanMessage): + content = msg.content + if isinstance(content, str) and content.strip(): + parts.append(content.strip()) + if not parts: + parts.append("Produce the requested report.") + return "\n\n".join(parts) + + +async def _run( + system_prompt: str, + user_prompt: str, + lc_tools: List[Any], + server_name: str, + model: str, +) -> str: + from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + TextBlock, + query, + ) + + server, allowed = build_mcp_server(server_name, lc_tools) + + options = ClaudeAgentOptions( + model=model, + system_prompt=system_prompt, + mcp_servers={server_name: server}, + allowed_tools=allowed, + # Block the Claude Code built-ins; only our MCP tools should run. + disallowed_tools=[ + "Bash", "Read", "Write", "Edit", "MultiEdit", + "Glob", "Grep", "WebFetch", "WebSearch", + "Task", "TodoWrite", "NotebookEdit", + ], + permission_mode="bypassPermissions", + ) + + text_parts: List[str] = [] + async for msg in query(prompt=user_prompt, options=options): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + text_parts.append(block.text) + return "\n".join(text_parts).strip() + + +def run_sdk_analyst( + llm: ChatClaudeAgent, + state: Dict[str, Any], + system_prompt: str, + lc_tools: List[Any], + server_name: str, + report_field: str, +) -> Dict[str, Any]: + """Run an analyst through the Claude Agent SDK tool loop and build the node output.""" + user_prompt = _build_user_prompt(state) + report = asyncio.run( + _run( + system_prompt=system_prompt, + user_prompt=user_prompt, + lc_tools=lc_tools, + server_name=server_name, + model=llm.model, + ) + ) + return { + "messages": [AIMessage(content=report)], + report_field: report, + } diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 6aa49cf3..837627c1 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,4 +1,5 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_balance_sheet, @@ -9,6 +10,7 @@ from tradingagents.agents.utils.agent_utils import ( get_language_instruction, ) from tradingagents.dataflows.config import get_config +from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent def create_fundamentals_analyst(llm): @@ -52,6 +54,23 @@ def create_fundamentals_analyst(llm): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) + if isinstance(llm, ChatClaudeAgent): + full_system = ( + "You are a helpful AI assistant. Use the provided tools to progress towards " + "producing the requested report. " + f"You have access to the following tools: {', '.join(t.name for t in tools)}. " + f"For your reference, the current date is {current_date}. {instrument_context}\n\n" + f"{system_message}" + ) + return run_sdk_analyst( + llm=llm, + state=state, + system_prompt=full_system, + lc_tools=tools, + server_name="fundamentals_analyst", + report_field="fundamentals_report", + ) + chain = prompt | llm.bind_tools(tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index fef8f751..0d7ee372 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,4 +1,5 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_indicators, @@ -6,6 +7,7 @@ from tradingagents.agents.utils.agent_utils import ( get_stock_data, ) from tradingagents.dataflows.config import get_config +from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent def create_market_analyst(llm): @@ -71,6 +73,23 @@ Volume-Based Indicators: prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) + if isinstance(llm, ChatClaudeAgent): + full_system = ( + "You are a helpful AI assistant. Use the provided tools to progress towards " + "producing the requested report. " + f"You have access to the following tools: {', '.join(t.name for t in tools)}. " + f"For your reference, the current date is {current_date}. {instrument_context}\n\n" + f"{system_message}" + ) + return run_sdk_analyst( + llm=llm, + state=state, + system_prompt=full_system, + lc_tools=tools, + server_name="market_analyst", + report_field="market_report", + ) + chain = prompt | llm.bind_tools(tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index e0fe93c5..94843ed7 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,4 +1,5 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst from tradingagents.agents.utils.agent_utils import ( build_instrument_context, get_global_news, @@ -6,6 +7,7 @@ from tradingagents.agents.utils.agent_utils import ( get_news, ) from tradingagents.dataflows.config import get_config +from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent def create_news_analyst(llm): @@ -46,6 +48,23 @@ def create_news_analyst(llm): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) + if isinstance(llm, ChatClaudeAgent): + full_system = ( + "You are a helpful AI assistant. Use the provided tools to progress towards " + "producing the requested report. " + f"You have access to the following tools: {', '.join(t.name for t in tools)}. " + f"For your reference, the current date is {current_date}. {instrument_context}\n\n" + f"{system_message}" + ) + return run_sdk_analyst( + llm=llm, + state=state, + system_prompt=full_system, + lc_tools=tools, + server_name="news_analyst", + report_field="news_report", + ) + chain = prompt | llm.bind_tools(tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 34a53c46..417ac141 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,6 +1,8 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news from tradingagents.dataflows.config import get_config +from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent def create_social_media_analyst(llm): @@ -40,6 +42,23 @@ def create_social_media_analyst(llm): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) + if isinstance(llm, ChatClaudeAgent): + full_system = ( + "You are a helpful AI assistant. Use the provided tools to progress towards " + "producing the requested report. " + f"You have access to the following tools: {', '.join(t.name for t in tools)}. " + f"For your reference, the current date is {current_date}. {instrument_context}\n\n" + f"{system_message}" + ) + return run_sdk_analyst( + llm=llm, + state=state, + system_prompt=full_system, + lc_tools=tools, + server_name="social_analyst", + report_field="sentiment_report", + ) + chain = prompt | llm.bind_tools(tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/llm_clients/mcp_tool_adapter.py b/tradingagents/llm_clients/mcp_tool_adapter.py new file mode 100644 index 00000000..bd74f1fd --- /dev/null +++ b/tradingagents/llm_clients/mcp_tool_adapter.py @@ -0,0 +1,48 @@ +"""Translate LangChain @tool-decorated functions into claude-agent-sdk MCP tools. + +LangChain tools expose .name, .description, a Pydantic args_schema, and a sync +.invoke({...}). The SDK wants an @tool-decorated async callable that takes a +dict and returns {"content": [{"type": "text", "text": str}]}. + +Used by the SDK-native analyst runner to let Claude Code (authenticated via a +Max/Pro subscription) call the same data tools the legacy analyst path uses. +""" + +from typing import Any, Dict, List, Tuple + +from claude_agent_sdk import create_sdk_mcp_server, tool + + +def _wrap_lc_tool(lc_tool: Any): + """Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable.""" + schema = ( + lc_tool.args_schema.model_json_schema() + if lc_tool.args_schema is not None + else {"type": "object", "properties": {}} + ) + + @tool( + name=lc_tool.name, + description=lc_tool.description or lc_tool.name, + input_schema=schema, + ) + async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]: + result = lc_tool.invoke(args) + return {"content": [{"type": "text", "text": str(result)}]} + + return _wrapped + + +def build_mcp_server( + server_name: str, + lc_tools: List[Any], +) -> Tuple[Any, List[str]]: + """Build an in-process MCP server from LangChain tools. + + Returns the server instance and the list of fully-qualified tool names + (``mcp____``) suitable for passing to ``allowed_tools``. + """ + wrapped = [_wrap_lc_tool(t) for t in lc_tools] + server = create_sdk_mcp_server(name=server_name, version="1.0.0", tools=wrapped) + allowed = [f"mcp__{server_name}__{t.name}" for t in lc_tools] + return server, allowed