feat: add claude_agent LLM provider backed by Claude Agent SDK (Shape A)
Routes inference through Claude Code's OAuth session, so a Claude Max/Pro subscription authenticates without an Anthropic API key. Shape A supports plain .invoke() for prompt-only call sites (researchers, managers, trader, reflection, signal processing); bind_tools raises NotImplementedError until Shape B rewrites analysts to use the SDK's native tool loop. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
fa4d01c23a
commit
9a0bbb8294
|
|
@ -0,0 +1,66 @@
|
||||||
|
"""Smoke test for the claude_agent provider (Shape A).
|
||||||
|
|
||||||
|
Exercises the three call patterns used in TradingAgents:
|
||||||
|
1. Plain string prompt (trader, researchers)
|
||||||
|
2. List of role/content dicts (trader)
|
||||||
|
3. List of LangChain messages via ChatPromptTemplate (simulated manager path)
|
||||||
|
|
||||||
|
Prints the response text and timing for each. Any exception = failure.
|
||||||
|
Requires the user to be logged into Claude Code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.factory import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
|
def run(label, invocation):
    """Execute *invocation*, then print a labelled timing line and a preview.

    Args:
        label: Human-readable name printed in the section header.
        invocation: Zero-argument callable performing one LLM call.

    Returns:
        None. Output goes to stdout; any exception propagates (= test failure).
    """
    start = time.monotonic()
    result = invocation()
    elapsed = time.monotonic() - start
    # LangChain chat models return message objects with .content;
    # plain strings (or other results) pass through unchanged.
    content = result.content if hasattr(result, "content") else result
    # Some providers return structured (non-string) content; coerce so the
    # preview slice below can always be concatenated with the ellipsis.
    text = content if isinstance(content, str) else str(content)
    preview = (text[:300] + "…") if len(text) > 300 else text
    print(f"\n--- {label} ({elapsed:.1f}s) ---")
    print(preview)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Drive the three invocation patterns plus the bind_tools guard."""
    llm = create_llm_client(
        provider="claude_agent",
        model="sonnet",
    ).get_llm()

    # Pattern 1: bare string prompt (trader, researchers).
    run("1. string prompt",
        lambda: llm.invoke("In one sentence, what is a P/E ratio?"))

    # Pattern 2: role/content dicts (trader).
    dict_messages = [
        {"role": "system", "content": "You are a concise financial analyst."},
        {"role": "user", "content": "In one sentence, define moving average."},
    ]
    run("2. role/content dicts", lambda: llm.invoke(dict_messages))

    # Pattern 3: ChatPromptTemplate piped into the model (manager path).
    template = ChatPromptTemplate.from_messages([
        ("system", "You are a concise financial analyst."),
        MessagesPlaceholder(variable_name="messages"),
    ])
    chain = template | llm
    run("3. ChatPromptTemplate pipe",
        lambda: chain.invoke({
            "messages": [HumanMessage(content="In one sentence, define RSI.")]
        }))

    # Pattern 4: Shape A must refuse tool binding loudly, not silently no-op.
    try:
        llm.bind_tools([])
    except NotImplementedError as exc:
        print(f"\n--- 4. bind_tools guard ---\nRaised as expected: {exc}")
    else:
        raise AssertionError("bind_tools should have raised NotImplementedError")

    print("\nSmoke test OK.")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the smoke test when executed directly.
if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,134 @@
|
||||||
|
"""Claude Agent SDK client — routes inference through Claude Code's OAuth session.
|
||||||
|
|
||||||
|
Works with a Claude Max/Pro subscription; no ANTHROPIC_API_KEY required. Requires
|
||||||
|
the `claude-agent-sdk` package (bundles the Claude Code CLI).
|
||||||
|
|
||||||
|
Shape A: supports plain .invoke() for prompt-only call sites (researchers,
|
||||||
|
managers, trader, reflection, signal processing). Tool binding raises
|
||||||
|
NotImplementedError — tool-using analysts must use a different provider until
|
||||||
|
Shape B.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
|
||||||
|
from langchain_core.prompt_values import PromptValue
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient
|
||||||
|
|
||||||
|
|
||||||
|
# Tools the built-in Claude Code preset would otherwise enable. We disable them
# explicitly so the SDK behaves as a pure LLM for Shape A.
# NOTE(review): assumed to match the preset's tool names at the pinned SDK
# version — confirm against the claude-agent-sdk docs when upgrading.
_DISABLED_BUILTIN_TOOLS = [
    "Bash", "Read", "Write", "Edit", "MultiEdit",
    "Glob", "Grep", "WebFetch", "WebSearch",
    "Task", "TodoWrite", "NotebookEdit",
]
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_input(input: Any) -> Tuple[Optional[str], str]:
    """Collapse LangChain input into (system_prompt, user_prompt).

    The SDK takes one `prompt` string and a separate `system_prompt` option.
    System-role content (SystemMessage instances or ``{"role": "system"}``
    dicts) is folded into the first element; everything else is tagged with
    its role and concatenated into the second.
    """
    if isinstance(input, str):
        return None, input

    # PromptValue (e.g. from ChatPromptTemplate) expands to a message list.
    if isinstance(input, PromptValue):
        input = input.to_messages()

    if not isinstance(input, list):
        return None, str(input)

    system_chunks: List[str] = []
    body_chunks: List[str] = []

    for item in input:
        # Check SystemMessage before BaseMessage: it is a BaseMessage subclass.
        if isinstance(item, SystemMessage):
            system_chunks.append(str(item.content))
        elif isinstance(item, BaseMessage):
            tag = getattr(item, "type", "human")
            body_chunks.append(f"[{tag}] {item.content}")
        elif isinstance(item, dict):
            role = item.get("role", "user")
            content = item.get("content", "")
            if role == "system":
                system_chunks.append(str(content))
            else:
                body_chunks.append(f"[{role}] {content}")
        else:
            body_chunks.append(str(item))

    return (
        "\n\n".join(system_chunks) if system_chunks else None,
        "\n\n".join(body_chunks),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ChatClaudeAgent(Runnable):
    """LangChain-compatible Runnable that routes inference through claude-agent-sdk.

    Authenticates via Claude Code's bundled CLI session. A Claude Max/Pro
    subscription satisfies auth; no API key required.
    """

    def __init__(self, model: str, **kwargs: Any) -> None:
        # model: Claude Code model alias or full ID (e.g. "sonnet").
        # kwargs: retained for forward compatibility; not forwarded to the SDK.
        self.model = model
        self.kwargs = kwargs

    def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> AIMessage:
        """Synchronously run one inference turn and return the assistant reply.

        `config` is accepted for Runnable-interface compatibility and ignored.
        """
        system_prompt, prompt = _coerce_input(input)
        coro = self._ainvoke(prompt, system_prompt)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop in this thread: the common synchronous path.
            return asyncio.run(coro)
        # We were called from inside a running event loop (e.g. an async
        # pipeline). asyncio.run() would raise RuntimeError there, so drive
        # the coroutine on a fresh loop in a dedicated worker thread.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    async def _ainvoke(self, prompt: str, system_prompt: Optional[str]) -> AIMessage:
        """Send one prompt through the SDK and collect the assistant text."""
        # Imported lazily so this module stays importable when
        # claude-agent-sdk is not installed (another provider configured).
        from claude_agent_sdk import (
            AssistantMessage,
            ClaudeAgentOptions,
            TextBlock,
            query,
        )

        options_kwargs: dict = {
            "model": self.model,
            # Shape A = pure LLM: no tools allowed, built-ins disabled.
            "allowed_tools": [],
            "disallowed_tools": list(_DISABLED_BUILTIN_TOOLS),
            "permission_mode": "default",
        }
        # The SDK distinguishes "no system prompt" from an empty one, so only
        # pass the key when we actually have system content.
        if system_prompt is not None:
            options_kwargs["system_prompt"] = system_prompt

        options = ClaudeAgentOptions(**options_kwargs)

        # query() streams a sequence of messages; keep only assistant text
        # blocks and drop everything else (tool/system/result frames).
        text_parts: List[str] = []
        async for msg in query(prompt=prompt, options=options):
            if isinstance(msg, AssistantMessage):
                for block in msg.content:
                    if isinstance(block, TextBlock):
                        text_parts.append(block.text)

        return AIMessage(content="\n".join(text_parts))

    def bind_tools(self, tools: Any, **kwargs: Any) -> Any:
        """Shape A guard: tool binding is unsupported — fail loudly with guidance."""
        raise NotImplementedError(
            "claude_agent provider does not yet support bind_tools (Shape A). "
            "Configure a different provider (anthropic, openai, etc.) for the "
            "4 analysts that call bind_tools, or wait for Shape B which rewrites "
            "analysts to use the SDK's native tool loop."
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeAgentClient(BaseLLMClient):
    """LLM client backed by claude-agent-sdk / Claude Code OAuth."""

    def get_llm(self) -> Any:
        """Return the LangChain-compatible runnable for the configured model."""
        self.warn_if_unknown_model()
        return ChatClaudeAgent(model=self.model, **self.kwargs)

    def validate_model(self) -> bool:
        """Accept any model string.

        Claude Code understands several alias forms (opus/sonnet/haiku, full
        IDs, short IDs); validation is deferred to the SDK, which rejects
        unknown strings itself.
        """
        return True
|
||||||
|
|
@ -5,6 +5,7 @@ from .openai_client import OpenAIClient
|
||||||
from .anthropic_client import AnthropicClient
|
from .anthropic_client import AnthropicClient
|
||||||
from .google_client import GoogleClient
|
from .google_client import GoogleClient
|
||||||
from .azure_client import AzureOpenAIClient
|
from .azure_client import AzureOpenAIClient
|
||||||
|
from .claude_agent_client import ClaudeAgentClient
|
||||||
|
|
||||||
# Providers that use the OpenAI-compatible chat completions API
|
# Providers that use the OpenAI-compatible chat completions API
|
||||||
_OPENAI_COMPATIBLE = (
|
_OPENAI_COMPATIBLE = (
|
||||||
|
|
@ -46,4 +47,7 @@ def create_llm_client(
|
||||||
if provider_lower == "azure":
|
if provider_lower == "azure":
|
||||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "claude_agent":
|
||||||
|
return ClaudeAgentClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue