feat: add claude_agent LLM provider backed by Claude Agent SDK (Shape A)
Routes inference through Claude Code's OAuth session, so a Claude Max/Pro subscription authenticates without an Anthropic API key. Shape A supports plain .invoke() for prompt-only call sites (researchers, managers, trader, reflection, signal processing); bind_tools raises NotImplementedError until Shape B rewrites analysts to use the SDK's native tool loop. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
fa4d01c23a
commit
9a0bbb8294
|
|
@ -0,0 +1,66 @@
|
|||
"""Smoke test for the claude_agent provider (Shape A).
|
||||
|
||||
Exercises the three call patterns used in TradingAgents:
|
||||
1. Plain string prompt (trader, researchers)
|
||||
2. List of role/content dicts (trader)
|
||||
3. List of LangChain messages via ChatPromptTemplate (simulated manager path)
|
||||
|
||||
Prints the response text and timing for each. Any exception = failure.
|
||||
Requires the user to be logged into Claude Code.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from tradingagents.llm_clients.factory import create_llm_client
|
||||
|
||||
|
||||
def run(label, invocation):
|
||||
start = time.monotonic()
|
||||
result = invocation()
|
||||
elapsed = time.monotonic() - start
|
||||
content = result.content if hasattr(result, "content") else result
|
||||
preview = (content[:300] + "…") if len(content) > 300 else content
|
||||
print(f"\n--- {label} ({elapsed:.1f}s) ---")
|
||||
print(preview)
|
||||
|
||||
|
||||
def main():
|
||||
client = create_llm_client(
|
||||
provider="claude_agent",
|
||||
model="sonnet",
|
||||
)
|
||||
llm = client.get_llm()
|
||||
|
||||
run("1. string prompt",
|
||||
lambda: llm.invoke("In one sentence, what is a P/E ratio?"))
|
||||
|
||||
run("2. role/content dicts",
|
||||
lambda: llm.invoke([
|
||||
{"role": "system", "content": "You are a concise financial analyst."},
|
||||
{"role": "user", "content": "In one sentence, define moving average."},
|
||||
]))
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", "You are a concise financial analyst."),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
])
|
||||
run("3. ChatPromptTemplate pipe",
|
||||
lambda: (prompt | llm).invoke({
|
||||
"messages": [HumanMessage(content="In one sentence, define RSI.")]
|
||||
}))
|
||||
|
||||
try:
|
||||
llm.bind_tools([])
|
||||
except NotImplementedError as e:
|
||||
print(f"\n--- 4. bind_tools guard ---\nRaised as expected: {e}")
|
||||
else:
|
||||
raise AssertionError("bind_tools should have raised NotImplementedError")
|
||||
|
||||
print("\nSmoke test OK.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
"""Claude Agent SDK client — routes inference through Claude Code's OAuth session.
|
||||
|
||||
Works with a Claude Max/Pro subscription; no ANTHROPIC_API_KEY required. Requires
|
||||
the `claude-agent-sdk` package (bundles the Claude Code CLI).
|
||||
|
||||
Shape A: supports plain .invoke() for prompt-only call sites (researchers,
|
||||
managers, trader, reflection, signal processing). Tool binding raises
|
||||
NotImplementedError — tool-using analysts must use a different provider until
|
||||
Shape B.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
||||
|
||||
# Tools the built-in Claude Code preset would otherwise enable. We disable them
|
||||
# explicitly so the SDK behaves as a pure LLM for Shape A.
|
||||
_DISABLED_BUILTIN_TOOLS = [
|
||||
"Bash", "Read", "Write", "Edit", "MultiEdit",
|
||||
"Glob", "Grep", "WebFetch", "WebSearch",
|
||||
"Task", "TodoWrite", "NotebookEdit",
|
||||
]
|
||||
|
||||
|
||||
def _coerce_input(input: Any) -> Tuple[Optional[str], str]:
|
||||
"""Collapse LangChain input into (system_prompt, user_prompt).
|
||||
|
||||
The SDK takes one `prompt` string and a separate `system_prompt` option.
|
||||
We fold any SystemMessage into system_prompt and concatenate the rest.
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
return None, input
|
||||
|
||||
if isinstance(input, PromptValue):
|
||||
input = input.to_messages()
|
||||
|
||||
if not isinstance(input, list):
|
||||
return None, str(input)
|
||||
|
||||
system_parts: List[str] = []
|
||||
user_parts: List[str] = []
|
||||
|
||||
for msg in input:
|
||||
if isinstance(msg, SystemMessage):
|
||||
system_parts.append(str(msg.content))
|
||||
continue
|
||||
if isinstance(msg, BaseMessage):
|
||||
role = getattr(msg, "type", "human")
|
||||
user_parts.append(f"[{role}] {msg.content}")
|
||||
continue
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
if role == "system":
|
||||
system_parts.append(str(content))
|
||||
else:
|
||||
user_parts.append(f"[{role}] {content}")
|
||||
continue
|
||||
user_parts.append(str(msg))
|
||||
|
||||
system_prompt = "\n\n".join(system_parts) if system_parts else None
|
||||
user_prompt = "\n\n".join(user_parts)
|
||||
return system_prompt, user_prompt
|
||||
|
||||
|
||||
class ChatClaudeAgent(Runnable):
|
||||
"""LangChain-compatible Runnable that routes inference through claude-agent-sdk.
|
||||
|
||||
Authenticates via Claude Code's bundled CLI session. A Claude Max/Pro
|
||||
subscription satisfies auth; no API key required.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs: Any) -> None:
|
||||
self.model = model
|
||||
self.kwargs = kwargs
|
||||
|
||||
def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> AIMessage:
|
||||
system_prompt, prompt = _coerce_input(input)
|
||||
return asyncio.run(self._ainvoke(prompt, system_prompt))
|
||||
|
||||
async def _ainvoke(self, prompt: str, system_prompt: Optional[str]) -> AIMessage:
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeAgentOptions,
|
||||
TextBlock,
|
||||
query,
|
||||
)
|
||||
|
||||
options_kwargs: dict = {
|
||||
"model": self.model,
|
||||
"allowed_tools": [],
|
||||
"disallowed_tools": list(_DISABLED_BUILTIN_TOOLS),
|
||||
"permission_mode": "default",
|
||||
}
|
||||
if system_prompt is not None:
|
||||
options_kwargs["system_prompt"] = system_prompt
|
||||
|
||||
options = ClaudeAgentOptions(**options_kwargs)
|
||||
|
||||
text_parts: List[str] = []
|
||||
async for msg in query(prompt=prompt, options=options):
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
text_parts.append(block.text)
|
||||
|
||||
return AIMessage(content="\n".join(text_parts))
|
||||
|
||||
def bind_tools(self, tools: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError(
|
||||
"claude_agent provider does not yet support bind_tools (Shape A). "
|
||||
"Configure a different provider (anthropic, openai, etc.) for the "
|
||||
"4 analysts that call bind_tools, or wait for Shape B which rewrites "
|
||||
"analysts to use the SDK's native tool loop."
|
||||
)
|
||||
|
||||
|
||||
class ClaudeAgentClient(BaseLLMClient):
|
||||
"""LLM client backed by claude-agent-sdk / Claude Code OAuth."""
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
self.warn_if_unknown_model()
|
||||
return ChatClaudeAgent(model=self.model, **self.kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
# Claude Code accepts multiple model aliases (opus/sonnet/haiku, full IDs,
|
||||
# short IDs); pass through and let the SDK reject unknown strings.
|
||||
return True
|
||||
|
|
@ -5,6 +5,7 @@ from .openai_client import OpenAIClient
|
|||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .azure_client import AzureOpenAIClient
|
||||
from .claude_agent_client import ClaudeAgentClient
|
||||
|
||||
# Providers that use the OpenAI-compatible chat completions API
|
||||
_OPENAI_COMPATIBLE = (
|
||||
|
|
@ -46,4 +47,7 @@ def create_llm_client(
|
|||
if provider_lower == "azure":
|
||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "claude_agent":
|
||||
return ClaudeAgentClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue