diff --git a/cli/main.py b/cli/main.py index 33d110fb..088b9d45 100644 --- a/cli/main.py +++ b/cli/main.py @@ -569,6 +569,7 @@ def get_user_selections(): thinking_level = None reasoning_effort = None anthropic_effort = None + bedrock_region = None provider_lower = selected_llm_provider.lower() if provider_lower == "google": @@ -595,6 +596,14 @@ def get_user_selections(): ) ) anthropic_effort = ask_anthropic_effort() + elif provider_lower == "bedrock": + console.print( + create_question_box( + "Step 8: AWS Region", + "Select the AWS region for Bedrock inference" + ) + ) + bedrock_region = ask_bedrock_region() return { "ticker": selected_ticker, @@ -608,6 +617,7 @@ def get_user_selections(): "google_thinking_level": thinking_level, "openai_reasoning_effort": reasoning_effort, "anthropic_effort": anthropic_effort, + "bedrock_region": bedrock_region, "output_language": output_language, } @@ -942,6 +952,7 @@ def run_analysis(): config["google_thinking_level"] = selections.get("google_thinking_level") config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") config["anthropic_effort"] = selections.get("anthropic_effort") + config["bedrock_region"] = selections.get("bedrock_region") config["output_language"] = selections.get("output_language", "English") # Create stats callback handler for tracking LLM/tool calls diff --git a/cli/utils.py b/cli/utils.py index 85c282ed..9133b246 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -240,6 +240,7 @@ def select_llm_provider() -> tuple[str, str | None]: ("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"), ("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"), ("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"), + ("Amazon Bedrock", "bedrock", None), ("Azure OpenAI", "azure", None), ("Ollama", "ollama", "http://localhost:11434/v1"), ] @@ -326,6 +327,25 @@ def ask_gemini_thinking_config() -> str | None: ).ask() +def ask_bedrock_region() -> str: + """Ask for AWS Bedrock region.""" + return questionary.select( + "Select AWS Region:", + choices=[ + questionary.Choice("US East 1 (N. Virginia) - default", "us-east-1"), + questionary.Choice("US West 2 (Oregon)", "us-west-2"), + questionary.Choice("EU West 1 (Ireland)", "eu-west-1"), + questionary.Choice("AP Northeast 1 (Tokyo)", "ap-northeast-1"), + questionary.Choice("AP Southeast 1 (Singapore)", "ap-southeast-1"), + ], + style=questionary.Style([ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ("pointer", "fg:cyan noinherit"), + ]), + ).ask() + + def ask_output_language() -> str: """Ask for report output language.""" choice = questionary.select( diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 78bc13e5..083acf57 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -151,6 +151,11 @@ class TradingAgentsGraph: if effort: kwargs["effort"] = effort + elif provider == "bedrock": + region = self.config.get("bedrock_region") + if region: + kwargs["region_name"] = region + return kwargs def _create_tool_nodes(self) -> Dict[str, ToolNode]: diff --git a/tradingagents/llm_clients/bedrock_client.py b/tradingagents/llm_clients/bedrock_client.py new file mode 100644 index 00000000..d0810575 --- /dev/null +++ b/tradingagents/llm_clients/bedrock_client.py @@ -0,0 +1,54 @@ +from typing import Any, Optional + +from botocore.config import Config as BotoConfig +from langchain_aws import ChatBedrockConverse + +from .base_client import BaseLLMClient, normalize_content + +_BEDROCK_MODELS = [ + "us.anthropic.claude-sonnet-4-6", + "us.anthropic.claude-haiku-4-5-20251001-v1:0", + "us.anthropic.claude-opus-4-6-v1", + "us.anthropic.claude-sonnet-4-5-20250929-v1:0", +] + +_PASSTHROUGH_KWARGS = ( + "region_name", + "credentials_profile_name", + "max_tokens", + "temperature", + "callbacks", +) + + +class NormalizedChatBedrockConverse(ChatBedrockConverse): + def invoke(self, input, config=None, **kwargs): + return normalize_content(super().invoke(input, config, **kwargs)) + + async def ainvoke(self, input, config=None, **kwargs): + return normalize_content(await super().ainvoke(input, config, **kwargs)) + + +class BedrockClient(BaseLLMClient): + """Client for Amazon Bedrock models via ChatBedrockConverse.""" + + def __init__(self, model: str, base_url: Optional[str] = None, **kwargs): + super().__init__(model, base_url, **kwargs) + + def get_llm(self) -> Any: + self.warn_if_unknown_model() + llm_kwargs = { + "model_id": self.model, + "config": BotoConfig(read_timeout=300, retries={"max_attempts": 3}), + } + if self.base_url and "openai.com" not in self.base_url: + llm_kwargs["endpoint_url"] = self.base_url + for key in _PASSTHROUGH_KWARGS: + if key in self.kwargs: + llm_kwargs[key] = self.kwargs[key] + if "region_name" not in llm_kwargs: + llm_kwargs["region_name"] = "us-east-1" + return NormalizedChatBedrockConverse(**llm_kwargs) + + def validate_model(self) -> bool: + return self.model in _BEDROCK_MODELS diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index a9a7e83d..23f8a722 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -5,6 +5,7 @@ from .openai_client import OpenAIClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient from .azure_client import AzureOpenAIClient +from .bedrock_client import BedrockClient # Providers that use the OpenAI-compatible chat completions API _OPENAI_COMPATIBLE = ( @@ -46,4 +47,7 @@ def create_llm_client( if provider_lower == "azure": return AzureOpenAIClient(model, base_url, **kwargs) + if provider_lower == "bedrock": + return BedrockClient(model, base_url, **kwargs) + raise ValueError(f"Unsupported LLM provider: {provider}") diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index a2c57ed8..375813ef 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -100,6 +100,16 @@ MODEL_OPTIONS: ProviderModeOptions = { ], }, # OpenRouter: fetched dynamically. Azure: any deployed model name. + "bedrock": { + "quick": [ + ("Claude Haiku 4.5 via Bedrock - Fast", "us.anthropic.claude-haiku-4-5-20251001-v1:0"), + ("Claude Sonnet 4.6 via Bedrock - Balanced", "us.anthropic.claude-sonnet-4-6"), + ], + "deep": [ + ("Claude Sonnet 4.6 via Bedrock - Balanced", "us.anthropic.claude-sonnet-4-6"), + ("Claude Opus 4.6 via Bedrock - Most capable", "us.anthropic.claude-opus-4-6-v1"), + ], + }, "ollama": { "quick": [ ("Qwen3:latest (8B, local)", "qwen3:latest"),