feat: route analyst tool loop through Claude Agent SDK (Shape B)
The 4 analysts (market, news, social, fundamentals) now detect a ChatClaudeAgent LLM and dispatch to an SDK-native runner: LangChain @tool functions are wrapped as in-process MCP tools via create_sdk_mcp_server, and the SDK owns the iterative tool-calling loop. Claude returns the final report in one call, so the analyst node outputs an AIMessage with no tool_calls and the existing conditional edges route straight to the message-clear step. Together with the Shape A provider this lets a Claude Max subscription drive the full TradingAgents graph without an Anthropic API key. Other providers continue to take the original bind_tools + LangGraph ToolNode path unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
4c1879d9f2
commit
2f870be9a8
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""Shape B smoke test: run the market analyst end-to-end against a ticker/date
|
||||||
|
via the claude_agent provider, confirming MCP tool translation and the SDK-
|
||||||
|
native tool loop.
|
||||||
|
|
||||||
|
Requires the user to be logged into Claude Code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from tradingagents.agents.analysts.market_analyst import create_market_analyst
|
||||||
|
from tradingagents.llm_clients.factory import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
client = create_llm_client(provider="claude_agent", model="sonnet")
|
||||||
|
llm = client.get_llm()
|
||||||
|
|
||||||
|
node = create_market_analyst(llm)
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"trade_date": "2025-10-15",
|
||||||
|
"company_of_interest": "AAPL",
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
"Produce a concise market analysis report for AAPL based on "
|
||||||
|
"the most recent price data and a few key technical indicators. "
|
||||||
|
"Keep it under 500 words."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Running market analyst on {state['company_of_interest']} @ {state['trade_date']}...")
|
||||||
|
start = time.monotonic()
|
||||||
|
output = node(state)
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
|
||||||
|
report = output.get("market_report", "")
|
||||||
|
print(f"\n--- market_report ({elapsed:.1f}s, {len(report)} chars) ---\n")
|
||||||
|
print(report)
|
||||||
|
|
||||||
|
assert report, "market_report is empty"
|
||||||
|
assert "messages" in output and len(output["messages"]) == 1
|
||||||
|
print("\nShape B smoke test OK.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
"""SDK-native analyst runner.
|
||||||
|
|
||||||
|
When the configured LLM is :class:`ChatClaudeAgent`, the analyst node delegates
|
||||||
|
the whole tool-calling loop to ``claude-agent-sdk``. The SDK owns the loop:
|
||||||
|
Claude iteratively invokes the translated MCP tools and returns a final text
|
||||||
|
report. No LangGraph ToolNode involvement — the analyst returns a terminal
|
||||||
|
AIMessage with zero tool_calls, so the existing conditional edges route
|
||||||
|
straight to the message-clear node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||||
|
from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server
|
||||||
|
|
||||||
|
|
||||||
|
def _build_user_prompt(state: Dict[str, Any]) -> str:
|
||||||
|
"""Extract any human content from the incoming message sequence.
|
||||||
|
|
||||||
|
Existing analysts rely on LangGraph feeding tool-call round trips through
|
||||||
|
state["messages"]. On the SDK path we collapse the incoming messages into a
|
||||||
|
single user prompt — tool results are consumed by the SDK loop, not via
|
||||||
|
LangGraph, so only the human-authored content matters here.
|
||||||
|
"""
|
||||||
|
parts: List[str] = []
|
||||||
|
for msg in state.get("messages", []):
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
content = msg.content
|
||||||
|
if isinstance(content, str) and content.strip():
|
||||||
|
parts.append(content.strip())
|
||||||
|
if not parts:
|
||||||
|
parts.append("Produce the requested report.")
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run(
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
lc_tools: List[Any],
|
||||||
|
server_name: str,
|
||||||
|
model: str,
|
||||||
|
) -> str:
|
||||||
|
from claude_agent_sdk import (
|
||||||
|
AssistantMessage,
|
||||||
|
ClaudeAgentOptions,
|
||||||
|
TextBlock,
|
||||||
|
query,
|
||||||
|
)
|
||||||
|
|
||||||
|
server, allowed = build_mcp_server(server_name, lc_tools)
|
||||||
|
|
||||||
|
options = ClaudeAgentOptions(
|
||||||
|
model=model,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
mcp_servers={server_name: server},
|
||||||
|
allowed_tools=allowed,
|
||||||
|
# Block the Claude Code built-ins; only our MCP tools should run.
|
||||||
|
disallowed_tools=[
|
||||||
|
"Bash", "Read", "Write", "Edit", "MultiEdit",
|
||||||
|
"Glob", "Grep", "WebFetch", "WebSearch",
|
||||||
|
"Task", "TodoWrite", "NotebookEdit",
|
||||||
|
],
|
||||||
|
permission_mode="bypassPermissions",
|
||||||
|
)
|
||||||
|
|
||||||
|
text_parts: List[str] = []
|
||||||
|
async for msg in query(prompt=user_prompt, options=options):
|
||||||
|
if isinstance(msg, AssistantMessage):
|
||||||
|
for block in msg.content:
|
||||||
|
if isinstance(block, TextBlock):
|
||||||
|
text_parts.append(block.text)
|
||||||
|
return "\n".join(text_parts).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def run_sdk_analyst(
|
||||||
|
llm: ChatClaudeAgent,
|
||||||
|
state: Dict[str, Any],
|
||||||
|
system_prompt: str,
|
||||||
|
lc_tools: List[Any],
|
||||||
|
server_name: str,
|
||||||
|
report_field: str,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Run an analyst through the Claude Agent SDK tool loop and build the node output."""
|
||||||
|
user_prompt = _build_user_prompt(state)
|
||||||
|
report = asyncio.run(
|
||||||
|
_run(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
lc_tools=lc_tools,
|
||||||
|
server_name=server_name,
|
||||||
|
model=llm.model,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content=report)],
|
||||||
|
report_field: report,
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
build_instrument_context,
|
||||||
get_balance_sheet,
|
get_balance_sheet,
|
||||||
|
|
@ -9,6 +10,7 @@ from tradingagents.agents.utils.agent_utils import (
|
||||||
get_language_instruction,
|
get_language_instruction,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
|
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||||
|
|
||||||
|
|
||||||
def create_fundamentals_analyst(llm):
|
def create_fundamentals_analyst(llm):
|
||||||
|
|
@ -52,6 +54,23 @@ def create_fundamentals_analyst(llm):
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(instrument_context=instrument_context)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
|
if isinstance(llm, ChatClaudeAgent):
|
||||||
|
full_system = (
|
||||||
|
"You are a helpful AI assistant. Use the provided tools to progress towards "
|
||||||
|
"producing the requested report. "
|
||||||
|
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
|
||||||
|
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
|
||||||
|
f"{system_message}"
|
||||||
|
)
|
||||||
|
return run_sdk_analyst(
|
||||||
|
llm=llm,
|
||||||
|
state=state,
|
||||||
|
system_prompt=full_system,
|
||||||
|
lc_tools=tools,
|
||||||
|
server_name="fundamentals_analyst",
|
||||||
|
report_field="fundamentals_report",
|
||||||
|
)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
|
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
build_instrument_context,
|
||||||
get_indicators,
|
get_indicators,
|
||||||
|
|
@ -6,6 +7,7 @@ from tradingagents.agents.utils.agent_utils import (
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
|
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||||
|
|
||||||
|
|
||||||
def create_market_analyst(llm):
|
def create_market_analyst(llm):
|
||||||
|
|
@ -71,6 +73,23 @@ Volume-Based Indicators:
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(instrument_context=instrument_context)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
|
if isinstance(llm, ChatClaudeAgent):
|
||||||
|
full_system = (
|
||||||
|
"You are a helpful AI assistant. Use the provided tools to progress towards "
|
||||||
|
"producing the requested report. "
|
||||||
|
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
|
||||||
|
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
|
||||||
|
f"{system_message}"
|
||||||
|
)
|
||||||
|
return run_sdk_analyst(
|
||||||
|
llm=llm,
|
||||||
|
state=state,
|
||||||
|
system_prompt=full_system,
|
||||||
|
lc_tools=tools,
|
||||||
|
server_name="market_analyst",
|
||||||
|
report_field="market_report",
|
||||||
|
)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
|
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
build_instrument_context,
|
build_instrument_context,
|
||||||
get_global_news,
|
get_global_news,
|
||||||
|
|
@ -6,6 +7,7 @@ from tradingagents.agents.utils.agent_utils import (
|
||||||
get_news,
|
get_news,
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
|
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||||
|
|
||||||
|
|
||||||
def create_news_analyst(llm):
|
def create_news_analyst(llm):
|
||||||
|
|
@ -46,6 +48,23 @@ def create_news_analyst(llm):
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(instrument_context=instrument_context)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
|
if isinstance(llm, ChatClaudeAgent):
|
||||||
|
full_system = (
|
||||||
|
"You are a helpful AI assistant. Use the provided tools to progress towards "
|
||||||
|
"producing the requested report. "
|
||||||
|
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
|
||||||
|
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
|
||||||
|
f"{system_message}"
|
||||||
|
)
|
||||||
|
return run_sdk_analyst(
|
||||||
|
llm=llm,
|
||||||
|
state=state,
|
||||||
|
system_prompt=full_system,
|
||||||
|
lc_tools=tools,
|
||||||
|
server_name="news_analyst",
|
||||||
|
report_field="news_report",
|
||||||
|
)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
from tradingagents.agents.analysts._claude_agent_runner import run_sdk_analyst
|
||||||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
|
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
|
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||||
|
|
||||||
|
|
||||||
def create_social_media_analyst(llm):
|
def create_social_media_analyst(llm):
|
||||||
|
|
@ -40,6 +42,23 @@ def create_social_media_analyst(llm):
|
||||||
prompt = prompt.partial(current_date=current_date)
|
prompt = prompt.partial(current_date=current_date)
|
||||||
prompt = prompt.partial(instrument_context=instrument_context)
|
prompt = prompt.partial(instrument_context=instrument_context)
|
||||||
|
|
||||||
|
if isinstance(llm, ChatClaudeAgent):
|
||||||
|
full_system = (
|
||||||
|
"You are a helpful AI assistant. Use the provided tools to progress towards "
|
||||||
|
"producing the requested report. "
|
||||||
|
f"You have access to the following tools: {', '.join(t.name for t in tools)}. "
|
||||||
|
f"For your reference, the current date is {current_date}. {instrument_context}\n\n"
|
||||||
|
f"{system_message}"
|
||||||
|
)
|
||||||
|
return run_sdk_analyst(
|
||||||
|
llm=llm,
|
||||||
|
state=state,
|
||||||
|
system_prompt=full_system,
|
||||||
|
lc_tools=tools,
|
||||||
|
server_name="social_analyst",
|
||||||
|
report_field="sentiment_report",
|
||||||
|
)
|
||||||
|
|
||||||
chain = prompt | llm.bind_tools(tools)
|
chain = prompt | llm.bind_tools(tools)
|
||||||
|
|
||||||
result = chain.invoke(state["messages"])
|
result = chain.invoke(state["messages"])
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
"""Translate LangChain @tool-decorated functions into claude-agent-sdk MCP tools.
|
||||||
|
|
||||||
|
LangChain tools expose .name, .description, a Pydantic args_schema, and a sync
|
||||||
|
.invoke({...}). The SDK wants an @tool-decorated async callable that takes a
|
||||||
|
dict and returns {"content": [{"type": "text", "text": str}]}.
|
||||||
|
|
||||||
|
Used by the SDK-native analyst runner to let Claude Code (authenticated via a
|
||||||
|
Max/Pro subscription) call the same data tools the legacy analyst path uses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_lc_tool(lc_tool: Any):
|
||||||
|
"""Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable."""
|
||||||
|
schema = (
|
||||||
|
lc_tool.args_schema.model_json_schema()
|
||||||
|
if lc_tool.args_schema is not None
|
||||||
|
else {"type": "object", "properties": {}}
|
||||||
|
)
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name=lc_tool.name,
|
||||||
|
description=lc_tool.description or lc_tool.name,
|
||||||
|
input_schema=schema,
|
||||||
|
)
|
||||||
|
async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
result = lc_tool.invoke(args)
|
||||||
|
return {"content": [{"type": "text", "text": str(result)}]}
|
||||||
|
|
||||||
|
return _wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def build_mcp_server(
|
||||||
|
server_name: str,
|
||||||
|
lc_tools: List[Any],
|
||||||
|
) -> Tuple[Any, List[str]]:
|
||||||
|
"""Build an in-process MCP server from LangChain tools.
|
||||||
|
|
||||||
|
Returns the server instance and the list of fully-qualified tool names
|
||||||
|
(``mcp__<server>__<tool>``) suitable for passing to ``allowed_tools``.
|
||||||
|
"""
|
||||||
|
wrapped = [_wrap_lc_tool(t) for t in lc_tools]
|
||||||
|
server = create_sdk_mcp_server(name=server_name, version="1.0.0", tools=wrapped)
|
||||||
|
allowed = [f"mcp__{server_name}__{t.name}" for t in lc_tools]
|
||||||
|
return server, allowed
|
||||||
Loading…
Reference in New Issue