feat: route analyst tool loop through Claude Agent SDK (Shape B)
The 4 analysts (market, news, social, fundamentals) now detect a ChatClaudeAgent LLM and dispatch to an SDK-native runner: LangChain @tool functions are wrapped as in-process MCP tools via create_sdk_mcp_server, and the SDK owns the iterative tool-calling loop. Claude returns the final report in one call, so the analyst node outputs an AIMessage with no tool_calls and the existing conditional edges route straight to the message-clear step. Together with the Shape A provider this lets a Claude Max subscription drive the full TradingAgents graph without an Anthropic API key. Other providers continue to take the original bind_tools + LangGraph ToolNode path unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
4c1879d9f2
commit
2f870be9a8
|
|
@ -0,0 +1,51 @@
|
|||
"""Shape B smoke test: run the market analyst end-to-end against a ticker/date
|
||||
via the claude_agent provider, confirming MCP tool translation and the SDK-
|
||||
native tool loop.
|
||||
|
||||
Requires the user to be logged into Claude Code.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from tradingagents.agents.analysts.market_analyst import create_market_analyst
|
||||
from tradingagents.llm_clients.factory import create_llm_client
|
||||
|
||||
|
||||
def main():
|
||||
client = create_llm_client(provider="claude_agent", model="sonnet")
|
||||
llm = client.get_llm()
|
||||
|
||||
node = create_market_analyst(llm)
|
||||
|
||||
state = {
|
||||
"trade_date": "2025-10-15",
|
||||
"company_of_interest": "AAPL",
|
||||
"messages": [
|
||||
HumanMessage(
|
||||
content=(
|
||||
"Produce a concise market analysis report for AAPL based on "
|
||||
"the most recent price data and a few key technical indicators. "
|
||||
"Keep it under 500 words."
|
||||
)
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
print(f"Running market analyst on {state['company_of_interest']} @ {state['trade_date']}...")
|
||||
start = time.monotonic()
|
||||
output = node(state)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
report = output.get("market_report", "")
|
||||
print(f"\n--- market_report ({elapsed:.1f}s, {len(report)} chars) ---\n")
|
||||
print(report)
|
||||
|
||||
assert report, "market_report is empty"
|
||||
assert "messages" in output and len(output["messages"]) == 1
|
||||
print("\nShape B smoke test OK.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
"""SDK-native analyst runner.
|
||||
|
||||
When the configured LLM is :class:`ChatClaudeAgent`, the analyst node delegates
|
||||
the whole tool-calling loop to ``claude-agent-sdk``. The SDK owns the loop:
|
||||
Claude iteratively invokes the translated MCP tools and returns a final text
|
||||
report. No LangGraph ToolNode involvement — the analyst returns a terminal
|
||||
AIMessage with zero tool_calls, so the existing conditional edges route
|
||||
straight to the message-clear node.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||
from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server
|
||||
|
||||
|
||||
def _build_user_prompt(state: Dict[str, Any]) -> str:
|
||||
"""Extract any human content from the incoming message sequence.
|
||||
|
||||
Existing analysts rely on LangGraph feeding tool-call round trips through
|
||||
state["messages"]. On the SDK path we collapse the incoming messages into a
|
||||
single user prompt — tool results are consumed by the SDK loop, not via
|
||||
LangGraph, so only the human-authored content matters here.
|
||||
"""
|
||||
parts: List[str] = []
|
||||
for msg in state.get("messages", []):
|
||||
if isinstance(msg, HumanMessage):
|
||||
content = msg.content
|
||||
if isinstance(content, str) and content.strip():
|
||||
parts.append(content.strip())
|
||||
if not parts:
|
||||
parts.append("Produce the requested report.")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
async def _run(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
lc_tools: List[Any],
|
||||
server_name: str,
|
||||
model: str,
|
||||
) -> str:
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeAgentOptions,
|
||||
TextBlock,
|
||||
query,
|
||||
)
|
||||
|
||||
server, allowed = build_mcp_server(server_name, lc_tools)
|
||||
|
||||
options = ClaudeAgentOptions(
|
||||
model=model,
|
||||
system_prompt=system_prompt,
|
||||
mcp_servers={server_name: server},
|
||||
allowed_tools=allowed,
|
||||
# Block the Claude Code built-ins; only our MCP tools should run.
|
||||
disallowed_tools=[
|
||||
"Bash", "Read", "Write", "Edit", "MultiEdit",
|
||||
"Glob", "Grep", "WebFetch", "WebSearch",
|
||||
"Task", "TodoWrite", "NotebookEdit",
|
||||
],
|
||||
permission_mode="bypassPermissions",
|
||||
)
|
||||
|
||||
text_parts: List[str] = []
|
||||
async for msg in query(prompt=user_prompt, options=options):
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
text_parts.append(block.text)
|
||||
return "\n".join(text_parts).strip()
|
||||
|
||||
|
||||
def run_sdk_analyst(
|
||||
llm: ChatClaudeAgent,
|
||||
state: Dict[str, Any],
|
||||
system_prompt: str,
|
||||
lc_tools: List[Any],
|
||||
server_name: str,
|
||||
report_field: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run an analyst through the Claude Agent SDK tool loop and build the node output."""
|
||||
user_prompt = _build_user_prompt(state)
|
||||
report = asyncio.run(
|
||||
_run(
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
lc_tools=lc_tools,
|
||||
server_name=server_name,
|
||||
model=llm.model,
|
||||
)
|
||||
)
|
||||
return {
|
||||
"messages": [AIMessage(content=report)],
|
||||
report_field: report,
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_balance_sheet,
|
||||
|
|
@ -9,6 +10,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_language_instruction,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||
|
||||
|
||||
def create_fundamentals_analyst(llm):
|
||||
|
|
@ -52,6 +54,23 @@ def create_fundamentals_analyst(llm):
|
|||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
if isinstance(llm, ChatClaudeAgent):
|
||||
full_system = (
|
||||
"You are a helpful AI assistant. Use the provided tools to progress towards "
|
||||
"producing the requested report. "
|
||||
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
|
||||
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
|
||||
f"{system_message}"
|
||||
)
|
||||
return run_sdk_analyst(
|
||||
llm=llm,
|
||||
state=state,
|
||||
system_prompt=full_system,
|
||||
lc_tools=tools,
|
||||
server_name="fundamentals_analyst",
|
||||
report_field="fundamentals_report",
|
||||
)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_indicators,
|
||||
|
|
@ -6,6 +7,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_stock_data,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||
|
||||
|
||||
def create_market_analyst(llm):
|
||||
|
|
@ -71,6 +73,23 @@ Volume-Based Indicators:
|
|||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
if isinstance(llm, ChatClaudeAgent):
|
||||
full_system = (
|
||||
"You are a helpful AI assistant. Use the provided tools to progress towards "
|
||||
"producing the requested report. "
|
||||
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
|
||||
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
|
||||
f"{system_message}"
|
||||
)
|
||||
return run_sdk_analyst(
|
||||
llm=llm,
|
||||
state=state,
|
||||
system_prompt=full_system,
|
||||
lc_tools=tools,
|
||||
server_name="market_analyst",
|
||||
report_field="market_report",
|
||||
)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_instrument_context,
|
||||
get_global_news,
|
||||
|
|
@ -6,6 +7,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_news,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||
|
||||
|
||||
def create_news_analyst(llm):
|
||||
|
|
@ -46,6 +48,23 @@ def create_news_analyst(llm):
|
|||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
if isinstance(llm, ChatClaudeAgent):
|
||||
full_system = (
|
||||
"You are a helpful AI assistant. Use the provided tools to progress towards "
|
||||
"producing the requested report. "
|
||||
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
|
||||
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
|
||||
f"{system_message}"
|
||||
)
|
||||
return run_sdk_analyst(
|
||||
llm=llm,
|
||||
state=state,
|
||||
system_prompt=full_system,
|
||||
lc_tools=tools,
|
||||
server_name="news_analyst",
|
||||
report_field="news_report",
|
||||
)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||
|
||||
|
||||
def create_social_media_analyst(llm):
|
||||
|
|
@ -40,6 +42,23 @@ def create_social_media_analyst(llm):
|
|||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
if isinstance(llm, ChatClaudeAgent):
|
||||
full_system = (
|
||||
"You are a helpful AI assistant. Use the provided tools to progress towards "
|
||||
"producing the requested report. "
|
||||
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
|
||||
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
|
||||
f"{system_message}"
|
||||
)
|
||||
return run_sdk_analyst(
|
||||
llm=llm,
|
||||
state=state,
|
||||
system_prompt=full_system,
|
||||
lc_tools=tools,
|
||||
server_name="social_analyst",
|
||||
report_field="sentiment_report",
|
||||
)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
"""Translate LangChain @tool-decorated functions into claude-agent-sdk MCP tools.
|
||||
|
||||
LangChain tools expose .name, .description, a Pydantic args_schema, and a sync
|
||||
.invoke({...}). The SDK wants an @tool-decorated async callable that takes a
|
||||
dict and returns {"content": [{"type": "text", "text": str}]}.
|
||||
|
||||
Used by the SDK-native analyst runner to let Claude Code (authenticated via a
|
||||
Max/Pro subscription) call the same data tools the legacy analyst path uses.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
|
||||
def _wrap_lc_tool(lc_tool: Any):
|
||||
"""Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable."""
|
||||
schema = (
|
||||
lc_tool.args_schema.model_json_schema()
|
||||
if lc_tool.args_schema is not None
|
||||
else {"type": "object", "properties": {}}
|
||||
)
|
||||
|
||||
@tool(
|
||||
name=lc_tool.name,
|
||||
description=lc_tool.description or lc_tool.name,
|
||||
input_schema=schema,
|
||||
)
|
||||
async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = lc_tool.invoke(args)
|
||||
return {"content": [{"type": "text", "text": str(result)}]}
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
def build_mcp_server(
|
||||
server_name: str,
|
||||
lc_tools: List[Any],
|
||||
) -> Tuple[Any, List[str]]:
|
||||
"""Build an in-process MCP server from LangChain tools.
|
||||
|
||||
Returns the server instance and the list of fully-qualified tool names
|
||||
(``mcp__<server>__<tool>``) suitable for passing to ``allowed_tools``.
|
||||
"""
|
||||
wrapped = [_wrap_lc_tool(t) for t in lc_tools]
|
||||
server = create_sdk_mcp_server(name=server_name, version="1.0.0", tools=wrapped)
|
||||
allowed = [f"mcp__{server_name}__{t.name}" for t in lc_tools]
|
||||
return server, allowed
|
||||
Loading…
Reference in New Issue