diff --git a/pyproject.toml b/pyproject.toml
index de27a2b9..a0444655 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
     "backtrader>=1.9.78.123",
     "langchain-anthropic>=0.3.15",
     "langchain-experimental>=0.3.4",
+    "langchain-aws>=0.2.0",
     "langchain-google-genai>=2.1.5",
     "langchain-openai>=0.3.23",
     "langgraph>=0.4.8",
diff --git a/tradingagents/llm_clients/bedrock_client.py b/tradingagents/llm_clients/bedrock_client.py
new file mode 100644
index 00000000..dd59d9a8
--- /dev/null
+++ b/tradingagents/llm_clients/bedrock_client.py
@@ -0,0 +1,27 @@
+from typing import Any, Optional
+
+from langchain_aws import ChatBedrockConverse
+
+from .base_client import BaseLLMClient
+
+
+class BedrockClient(BaseLLMClient):
+    """Client for Amazon Bedrock models (Claude, Kimi, Qwen, GLM, etc.)."""
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        super().__init__(model, base_url, **kwargs)
+
+    def get_llm(self) -> Any:
+        """Return a configured ChatBedrockConverse instance."""
+        llm_kwargs = {"model_id": self.model}
+
+        # Forward only the optional settings the caller actually supplied.
+        for key in ("region_name", "max_tokens", "callbacks", "timeout"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        return ChatBedrockConverse(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Validate model for Bedrock (pass-through, model IDs are flexible)."""
+        return True
diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py
index 93c2a7d3..24109640 100644
--- a/tradingagents/llm_clients/factory.py
+++ b/tradingagents/llm_clients/factory.py
@@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
 from .openai_client import OpenAIClient
 from .anthropic_client import AnthropicClient
 from .google_client import GoogleClient
+from .bedrock_client import BedrockClient
 
 
 def create_llm_client(
@@ -46,4 +47,7 @@ def create_llm_client(
     if provider_lower == "google":
         return GoogleClient(model, base_url, **kwargs)
 
+    if provider_lower == "bedrock":
+        return BedrockClient(model, base_url, **kwargs)
+
     raise ValueError(f"Unsupported LLM provider: {provider}")