From 9a0bbb8294753faa4dd1dde61479ed398246f162 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Apr 2026 15:53:26 -0400 Subject: [PATCH] feat: add claude_agent LLM provider backed by Claude Agent SDK (Shape A) Routes inference through Claude Code's OAuth session, so a Claude Max/Pro subscription authenticates without an Anthropic API key. Shape A supports plain .invoke() for prompt-only call sites (researchers, managers, trader, reflection, signal processing); bind_tools raises NotImplementedError until Shape B rewrites analysts to use the SDK's native tool loop. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/smoke_claude_agent.py | 66 +++++++++ .../llm_clients/claude_agent_client.py | 134 ++++++++++++++++++ tradingagents/llm_clients/factory.py | 4 + 3 files changed, 204 insertions(+) create mode 100644 scripts/smoke_claude_agent.py create mode 100644 tradingagents/llm_clients/claude_agent_client.py diff --git a/scripts/smoke_claude_agent.py b/scripts/smoke_claude_agent.py new file mode 100644 index 00000000..1819c0db --- /dev/null +++ b/scripts/smoke_claude_agent.py @@ -0,0 +1,66 @@ +"""Smoke test for the claude_agent provider (Shape A). + +Exercises the three call patterns used in TradingAgents: + 1. Plain string prompt (trader, researchers) + 2. List of role/content dicts (trader) + 3. List of LangChain messages via ChatPromptTemplate (simulated manager path) + +Prints the response text and timing for each. Any exception = failure. +Requires the user to be logged into Claude Code. +""" + +import time + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + +from tradingagents.llm_clients.factory import create_llm_client + + +def run(label, invocation): + start = time.monotonic() + result = invocation() + elapsed = time.monotonic() - start + content = result.content if hasattr(result, "content") else result + preview = (content[:300] + "…") if len(content) > 300 else content + print(f"\n--- {label} ({elapsed:.1f}s) ---") + print(preview) + + +def main(): + client = create_llm_client( + provider="claude_agent", + model="sonnet", + ) + llm = client.get_llm() + + run("1. string prompt", + lambda: llm.invoke("In one sentence, what is a P/E ratio?")) + + run("2. role/content dicts", + lambda: llm.invoke([ + {"role": "system", "content": "You are a concise financial analyst."}, + {"role": "user", "content": "In one sentence, define moving average."}, + ])) + + prompt = ChatPromptTemplate.from_messages([ + ("system", "You are a concise financial analyst."), + MessagesPlaceholder(variable_name="messages"), + ]) + run("3. ChatPromptTemplate pipe", + lambda: (prompt | llm).invoke({ + "messages": [HumanMessage(content="In one sentence, define RSI.")] + })) + + try: + llm.bind_tools([]) + except NotImplementedError as e: + print(f"\n--- 4. bind_tools guard ---\nRaised as expected: {e}") + else: + raise AssertionError("bind_tools should have raised NotImplementedError") + + print("\nSmoke test OK.") + + +if __name__ == "__main__": + main() diff --git a/tradingagents/llm_clients/claude_agent_client.py b/tradingagents/llm_clients/claude_agent_client.py new file mode 100644 index 00000000..ccdcf83b --- /dev/null +++ b/tradingagents/llm_clients/claude_agent_client.py @@ -0,0 +1,134 @@ +"""Claude Agent SDK client — routes inference through Claude Code's OAuth session. + +Works with a Claude Max/Pro subscription; no ANTHROPIC_API_KEY required. Requires +the `claude-agent-sdk` package (bundles the Claude Code CLI). + +Shape A: supports plain .invoke() for prompt-only call sites (researchers, +managers, trader, reflection, signal processing). Tool binding raises +NotImplementedError — tool-using analysts must use a different provider until +Shape B. +""" + +import asyncio +from typing import Any, List, Optional, Tuple + +from langchain_core.messages import AIMessage, BaseMessage, SystemMessage +from langchain_core.prompt_values import PromptValue +from langchain_core.runnables import Runnable + +from .base_client import BaseLLMClient + + +# Tools the built-in Claude Code preset would otherwise enable. We disable them +# explicitly so the SDK behaves as a pure LLM for Shape A. +_DISABLED_BUILTIN_TOOLS = [ + "Bash", "Read", "Write", "Edit", "MultiEdit", + "Glob", "Grep", "WebFetch", "WebSearch", + "Task", "TodoWrite", "NotebookEdit", +] + + +def _coerce_input(input: Any) -> Tuple[Optional[str], str]: + """Collapse LangChain input into (system_prompt, user_prompt). + + The SDK takes one `prompt` string and a separate `system_prompt` option. + We fold any SystemMessage into system_prompt and concatenate the rest. + """ + if isinstance(input, str): + return None, input + + if isinstance(input, PromptValue): + input = input.to_messages() + + if not isinstance(input, list): + return None, str(input) + + system_parts: List[str] = [] + user_parts: List[str] = [] + + for msg in input: + if isinstance(msg, SystemMessage): + system_parts.append(str(msg.content)) + continue + if isinstance(msg, BaseMessage): + role = getattr(msg, "type", "human") + user_parts.append(f"[{role}] {msg.content}") + continue + if isinstance(msg, dict): + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "system": + system_parts.append(str(content)) + else: + user_parts.append(f"[{role}] {content}") + continue + user_parts.append(str(msg)) + + system_prompt = "\n\n".join(system_parts) if system_parts else None + user_prompt = "\n\n".join(user_parts) + return system_prompt, user_prompt + + +class ChatClaudeAgent(Runnable): + """LangChain-compatible Runnable that routes inference through claude-agent-sdk. + + Authenticates via Claude Code's bundled CLI session. A Claude Max/Pro + subscription satisfies auth; no API key required. + """ + + def __init__(self, model: str, **kwargs: Any) -> None: + self.model = model + self.kwargs = kwargs + + def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> AIMessage: + system_prompt, prompt = _coerce_input(input) + return asyncio.run(self._ainvoke(prompt, system_prompt)) + + async def _ainvoke(self, prompt: str, system_prompt: Optional[str]) -> AIMessage: + from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + TextBlock, + query, + ) + + options_kwargs: dict = { + "model": self.model, + "allowed_tools": [], + "disallowed_tools": list(_DISABLED_BUILTIN_TOOLS), + "permission_mode": "default", + } + if system_prompt is not None: + options_kwargs["system_prompt"] = system_prompt + + options = ClaudeAgentOptions(**options_kwargs) + + text_parts: List[str] = [] + async for msg in query(prompt=prompt, options=options): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + text_parts.append(block.text) + + return AIMessage(content="\n".join(text_parts)) + + def bind_tools(self, tools: Any, **kwargs: Any) -> Any: + raise NotImplementedError( + "claude_agent provider does not yet support bind_tools (Shape A). " + "Configure a different provider (anthropic, openai, etc.) for the " + "4 analysts that call bind_tools, or wait for Shape B which rewrites " + "analysts to use the SDK's native tool loop." + ) + + +class ClaudeAgentClient(BaseLLMClient): + """LLM client backed by claude-agent-sdk / Claude Code OAuth.""" + + def get_llm(self) -> Any: + self.warn_if_unknown_model() + return ChatClaudeAgent(model=self.model, **self.kwargs) + + def validate_model(self) -> bool: + # Claude Code accepts multiple model aliases (opus/sonnet/haiku, full IDs, + # short IDs); pass through and let the SDK reject unknown strings. + return True diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index a9a7e83d..1533015d 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -5,6 +5,7 @@ from .openai_client import OpenAIClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient from .azure_client import AzureOpenAIClient +from .claude_agent_client import ClaudeAgentClient # Providers that use the OpenAI-compatible chat completions API _OPENAI_COMPATIBLE = ( @@ -46,4 +47,7 @@ def create_llm_client( if provider_lower == "azure": return AzureOpenAIClient(model, base_url, **kwargs) + if provider_lower == "claude_agent": + return ClaudeAgentClient(model, base_url, **kwargs) + raise ValueError(f"Unsupported LLM provider: {provider}")