feat: route analyst tool loop through Claude Agent SDK (Shape B)

The 4 analysts (market, news, social, fundamentals) now detect a
ChatClaudeAgent LLM and dispatch to an SDK-native runner: LangChain @tool
functions are wrapped as in-process MCP tools via create_sdk_mcp_server, and
the SDK owns the iterative tool-calling loop. Claude returns the final report
in one call, so the analyst node outputs an AIMessage with no tool_calls and
the existing conditional edges route straight to the message-clear step.

Together with the Shape A provider this lets a Claude Max subscription drive
the full TradingAgents graph without an Anthropic API key. Other providers
continue to take the original bind_tools + LangGraph ToolNode path unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Michael Yang 2026-04-14 16:08:40 -04:00
parent 4c1879d9f2
commit 2f870be9a8
7 changed files with 275 additions and 0 deletions

View File

@ -0,0 +1,51 @@
"""Shape B smoke test: run the market analyst end-to-end against a ticker/date
via the claude_agent provider, confirming MCP tool translation and the
SDK-native tool loop.
Requires the user to be logged into Claude Code.
"""
import time
from langchain_core.messages import HumanMessage
from tradingagents.agents.analysts.market_analyst import create_market_analyst
from tradingagents.llm_clients.factory import create_llm_client
def main():
    """Drive the market analyst once through the claude_agent provider."""
    ticker = "AAPL"
    trade_date = "2025-10-15"
    request = HumanMessage(
        content=(
            "Produce a concise market analysis report for AAPL based on "
            "the most recent price data and a few key technical indicators. "
            "Keep it under 500 words."
        )
    )
    state = {
        "trade_date": trade_date,
        "company_of_interest": ticker,
        "messages": [request],
    }

    client = create_llm_client(provider="claude_agent", model="sonnet")
    node = create_market_analyst(client.get_llm())

    print(f"Running market analyst on {ticker} @ {trade_date}...")
    start = time.monotonic()
    output = node(state)
    elapsed = time.monotonic() - start

    report = output.get("market_report", "")
    print(f"\n--- market_report ({elapsed:.1f}s, {len(report)} chars) ---\n")
    print(report)

    assert report, "market_report is empty"
    assert "messages" in output and len(output["messages"]) == 1
    print("\nShape B smoke test OK.")
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,100 @@
"""SDK-native analyst runner.
When the configured LLM is :class:`ChatClaudeAgent`, the analyst node delegates
the whole tool-calling loop to ``claude-agent-sdk``. The SDK owns the loop:
Claude iteratively invokes the translated MCP tools and returns a final text
report. No LangGraph ToolNode involvement — the analyst returns a terminal
AIMessage with zero tool_calls, so the existing conditional edges route
straight to the message-clear node.
"""
import asyncio
from typing import Any, Dict, List
from langchain_core.messages import AIMessage, HumanMessage
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server
def _build_user_prompt(state: Dict[str, Any]) -> str:
"""Extract any human content from the incoming message sequence.
Existing analysts rely on LangGraph feeding tool-call round trips through
state["messages"]. On the SDK path we collapse the incoming messages into a
single user prompt tool results are consumed by the SDK loop, not via
LangGraph, so only the human-authored content matters here.
"""
parts: List[str] = []
for msg in state.get("messages", []):
if isinstance(msg, HumanMessage):
content = msg.content
if isinstance(content, str) and content.strip():
parts.append(content.strip())
if not parts:
parts.append("Produce the requested report.")
return "\n\n".join(parts)
async def _run(
    system_prompt: str,
    user_prompt: str,
    lc_tools: List[Any],
    server_name: str,
    model: str,
) -> str:
    """Execute one SDK-owned tool loop and return the final report text.

    Translates the LangChain tools into an in-process MCP server, restricts
    Claude to exactly those tools, and concatenates every assistant text
    block emitted during the session.
    """
    # Imported lazily so the module stays importable without the SDK installed.
    from claude_agent_sdk import (
        AssistantMessage,
        ClaudeAgentOptions,
        TextBlock,
        query,
    )

    mcp_server, qualified_tool_names = build_mcp_server(server_name, lc_tools)
    options = ClaudeAgentOptions(
        model=model,
        system_prompt=system_prompt,
        mcp_servers={server_name: mcp_server},
        allowed_tools=qualified_tool_names,
        # Block the Claude Code built-ins; only our MCP tools should run.
        disallowed_tools=[
            "Bash", "Read", "Write", "Edit", "MultiEdit",
            "Glob", "Grep", "WebFetch", "WebSearch",
            "Task", "TodoWrite", "NotebookEdit",
        ],
        permission_mode="bypassPermissions",
    )

    chunks: List[str] = []
    async for message in query(prompt=user_prompt, options=options):
        if not isinstance(message, AssistantMessage):
            continue
        chunks.extend(
            block.text for block in message.content if isinstance(block, TextBlock)
        )
    return "\n".join(chunks).strip()
def run_sdk_analyst(
    llm: ChatClaudeAgent,
    state: Dict[str, Any],
    system_prompt: str,
    lc_tools: List[Any],
    server_name: str,
    report_field: str,
) -> Dict[str, Any]:
    """Run an analyst through the Claude Agent SDK tool loop and build the node output.

    Args:
        llm: The configured ChatClaudeAgent; only its ``model`` is read here.
        state: LangGraph analyst state; human messages become the user prompt.
        system_prompt: Fully assembled system prompt for this analyst.
        lc_tools: LangChain tools to expose via the in-process MCP server.
        server_name: MCP server name (also prefixes the qualified tool names).
        report_field: State key to store the final report under.

    Returns:
        A node-output dict with a terminal tool_call-free AIMessage plus the
        report text under ``report_field``, so the existing conditional edges
        route straight to the message-clear node.
    """
    user_prompt = _build_user_prompt(state)
    coro = _run(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        lc_tools=lc_tools,
        server_name=server_name,
        model=llm.model,
    )

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        running_loop = False
    else:
        running_loop = True

    if running_loop:
        # asyncio.run() raises RuntimeError when called from a running event
        # loop (e.g. graph.ainvoke or a notebook). Run the SDK session on its
        # own loop in a worker thread instead, blocking this node until done.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            report = pool.submit(asyncio.run, coro).result()
    else:
        report = asyncio.run(coro)

    return {
        "messages": [AIMessage(content=report)],
        report_field: report,
    }

View File

@ -1,4 +1,5 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_balance_sheet,
@ -9,6 +10,7 @@ from tradingagents.agents.utils.agent_utils import (
get_language_instruction,
)
from tradingagents.dataflows.config import get_config
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
def create_fundamentals_analyst(llm):
@ -52,6 +54,23 @@ def create_fundamentals_analyst(llm):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
if isinstance(llm, ChatClaudeAgent):
full_system = (
"You are a helpful AI assistant. Use the provided tools to progress towards "
"producing the requested report. "
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
f"{system_message}"
)
return run_sdk_analyst(
llm=llm,
state=state,
system_prompt=full_system,
lc_tools=tools,
server_name="fundamentals_analyst",
report_field="fundamentals_report",
)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])

View File

@ -1,4 +1,5 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_indicators,
@ -6,6 +7,7 @@ from tradingagents.agents.utils.agent_utils import (
get_stock_data,
)
from tradingagents.dataflows.config import get_config
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
def create_market_analyst(llm):
@ -71,6 +73,23 @@ Volume-Based Indicators:
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
if isinstance(llm, ChatClaudeAgent):
full_system = (
"You are a helpful AI assistant. Use the provided tools to progress towards "
"producing the requested report. "
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
f"{system_message}"
)
return run_sdk_analyst(
llm=llm,
state=state,
system_prompt=full_system,
lc_tools=tools,
server_name="market_analyst",
report_field="market_report",
)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])

View File

@ -1,4 +1,5 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_global_news,
@ -6,6 +7,7 @@ from tradingagents.agents.utils.agent_utils import (
get_news,
)
from tradingagents.dataflows.config import get_config
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
def create_news_analyst(llm):
@ -46,6 +48,23 @@ def create_news_analyst(llm):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
if isinstance(llm, ChatClaudeAgent):
full_system = (
"You are a helpful AI assistant. Use the provided tools to progress towards "
"producing the requested report. "
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
f"{system_message}"
)
return run_sdk_analyst(
llm=llm,
state=state,
system_prompt=full_system,
lc_tools=tools,
server_name="news_analyst",
report_field="news_report",
)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])

View File

@ -1,6 +1,8 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
from tradingagents.dataflows.config import get_config
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
def create_social_media_analyst(llm):
@ -40,6 +42,23 @@ def create_social_media_analyst(llm):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
if isinstance(llm, ChatClaudeAgent):
full_system = (
"You are a helpful AI assistant. Use the provided tools to progress towards "
"producing the requested report. "
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
f"{system_message}"
)
return run_sdk_analyst(
llm=llm,
state=state,
system_prompt=full_system,
lc_tools=tools,
server_name="social_analyst",
report_field="sentiment_report",
)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])

View File

@ -0,0 +1,48 @@
"""Translate LangChain @tool-decorated functions into claude-agent-sdk MCP tools.
LangChain tools expose .name, .description, a Pydantic args_schema, and a sync
.invoke({...}). The SDK wants an @tool-decorated async callable that takes a
dict and returns {"content": [{"type": "text", "text": str}]}.
Used by the SDK-native analyst runner to let Claude Code (authenticated via a
Max/Pro subscription) call the same data tools the legacy analyst path uses.
"""
from typing import Any, Dict, List, Tuple
from claude_agent_sdk import create_sdk_mcp_server, tool
def _wrap_lc_tool(lc_tool: Any):
    """Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable.

    The JSON schema is taken from the tool's Pydantic ``args_schema`` when
    present, otherwise an empty object schema is advertised. The returned
    callable yields the MCP content shape: ``{"content": [{"type": "text", ...}]}``.
    """
    if lc_tool.args_schema is not None:
        schema = lc_tool.args_schema.model_json_schema()
    else:
        schema = {"type": "object", "properties": {}}

    @tool(
        name=lc_tool.name,
        description=lc_tool.description or lc_tool.name,
        input_schema=schema,
    )
    async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]:
        import asyncio

        # LangChain's .invoke is synchronous (often a blocking network fetch);
        # run it off-loop so it doesn't stall the SDK's event loop while the
        # session is mid-conversation.
        result = await asyncio.to_thread(lc_tool.invoke, args)
        return {"content": [{"type": "text", "text": str(result)}]}

    return _wrapped
def build_mcp_server(
    server_name: str,
    lc_tools: List[Any],
) -> Tuple[Any, List[str]]:
    """Build an in-process MCP server from LangChain tools.

    Returns:
        A ``(server, allowed)`` pair: the server instance and the list of
        fully-qualified tool names (``mcp__<server>__<tool>``) suitable for
        passing to ``allowed_tools``.
    """
    sdk_tools = []
    qualified_names: List[str] = []
    for lc_tool in lc_tools:
        sdk_tools.append(_wrap_lc_tool(lc_tool))
        qualified_names.append(f"mcp__{server_name}__{lc_tool.name}")
    server = create_sdk_mcp_server(
        name=server_name, version="1.0.0", tools=sdk_tools
    )
    return server, qualified_names