feat(bedrock): add Amazon Bedrock LLM provider support
- Add BedrockClient implementation with ChatBedrockConverse integration - Support Claude models (Haiku, Sonnet, Opus) via AWS Bedrock - Add AWS region selection in CLI with 5 region options (us-east-1, us-west-2, eu-west-1, ap-northeast-1, ap-southeast-1) - Integrate Bedrock provider into LLM factory and model catalog - Pass region_name and other Bedrock-specific parameters through trading graph configuration - Set default region to us-east-1 with configurable boto retry and timeout settings
This commit is contained in:
parent
fa4d01c23a
commit
7ffe7ea310
11
cli/main.py
11
cli/main.py
|
|
@ -569,6 +569,7 @@ def get_user_selections():
|
|||
thinking_level = None
|
||||
reasoning_effort = None
|
||||
anthropic_effort = None
|
||||
bedrock_region = None
|
||||
|
||||
provider_lower = selected_llm_provider.lower()
|
||||
if provider_lower == "google":
|
||||
|
|
@ -595,6 +596,14 @@ def get_user_selections():
|
|||
)
|
||||
)
|
||||
anthropic_effort = ask_anthropic_effort()
|
||||
elif provider_lower == "bedrock":
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 8: AWS Region",
|
||||
"Select the AWS region for Bedrock inference"
|
||||
)
|
||||
)
|
||||
bedrock_region = ask_bedrock_region()
|
||||
|
||||
return {
|
||||
"ticker": selected_ticker,
|
||||
|
|
@ -608,6 +617,7 @@ def get_user_selections():
|
|||
"google_thinking_level": thinking_level,
|
||||
"openai_reasoning_effort": reasoning_effort,
|
||||
"anthropic_effort": anthropic_effort,
|
||||
"bedrock_region": bedrock_region,
|
||||
"output_language": output_language,
|
||||
}
|
||||
|
||||
|
|
@ -942,6 +952,7 @@ def run_analysis():
|
|||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
config["anthropic_effort"] = selections.get("anthropic_effort")
|
||||
config["bedrock_region"] = selections.get("bedrock_region")
|
||||
config["output_language"] = selections.get("output_language", "English")
|
||||
|
||||
# Create stats callback handler for tracking LLM/tool calls
|
||||
|
|
|
|||
20
cli/utils.py
20
cli/utils.py
|
|
@ -240,6 +240,7 @@ def select_llm_provider() -> tuple[str, str | None]:
|
|||
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Amazon Bedrock", "bedrock", None),
|
||||
("Azure OpenAI", "azure", None),
|
||||
("Ollama", "ollama", "http://localhost:11434/v1"),
|
||||
]
|
||||
|
|
@ -326,6 +327,25 @@ def ask_gemini_thinking_config() -> str | None:
|
|||
).ask()
|
||||
|
||||
|
||||
def ask_bedrock_region() -> str:
|
||||
"""Ask for AWS Bedrock region."""
|
||||
return questionary.select(
|
||||
"Select AWS Region:",
|
||||
choices=[
|
||||
questionary.Choice("US East 1 (N. Virginia) - default", "us-east-1"),
|
||||
questionary.Choice("US West 2 (Oregon)", "us-west-2"),
|
||||
questionary.Choice("EU West 1 (Ireland)", "eu-west-1"),
|
||||
questionary.Choice("AP Northeast 1 (Tokyo)", "ap-northeast-1"),
|
||||
questionary.Choice("AP Southeast 1 (Singapore)", "ap-southeast-1"),
|
||||
],
|
||||
style=questionary.Style([
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "fg:cyan noinherit"),
|
||||
("pointer", "fg:cyan noinherit"),
|
||||
]),
|
||||
).ask()
|
||||
|
||||
|
||||
def ask_output_language() -> str:
|
||||
"""Ask for report output language."""
|
||||
choice = questionary.select(
|
||||
|
|
|
|||
|
|
@ -151,6 +151,11 @@ class TradingAgentsGraph:
|
|||
if effort:
|
||||
kwargs["effort"] = effort
|
||||
|
||||
elif provider == "bedrock":
|
||||
region = self.config.get("bedrock_region")
|
||||
if region:
|
||||
kwargs["region_name"] = region
|
||||
|
||||
return kwargs
|
||||
|
||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from botocore.config import Config as BotoConfig
|
||||
from langchain_aws import ChatBedrockConverse
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
|
||||
_BEDROCK_MODELS = [
|
||||
"us.anthropic.claude-sonnet-4-6",
|
||||
"us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"us.anthropic.claude-opus-4-6-v1",
|
||||
"us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
]
|
||||
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
"region_name",
|
||||
"credentials_profile_name",
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"callbacks",
|
||||
)
|
||||
|
||||
|
||||
class NormalizedChatBedrockConverse(ChatBedrockConverse):
|
||||
def invoke(self, input, config=None, **kwargs):
|
||||
return normalize_content(super().invoke(input, config, **kwargs))
|
||||
|
||||
|
||||
class BedrockClient(BaseLLMClient):
|
||||
"""Client for Amazon Bedrock models via ChatBedrockConverse."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
self.warn_if_unknown_model()
|
||||
llm_kwargs = {
|
||||
"model_id": self.model,
|
||||
"config": BotoConfig(read_timeout=300, retries={"max_attempts": 3}),
|
||||
}
|
||||
for key in _PASSTHROUGH_KWARGS:
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
if "region_name" not in llm_kwargs:
|
||||
llm_kwargs["region_name"] = "us-east-1"
|
||||
return NormalizedChatBedrockConverse(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
return self.model in _BEDROCK_MODELS
|
||||
|
|
@ -5,6 +5,7 @@ from .openai_client import OpenAIClient
|
|||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .azure_client import AzureOpenAIClient
|
||||
from .bedrock_client import BedrockClient
|
||||
|
||||
# Providers that use the OpenAI-compatible chat completions API
|
||||
_OPENAI_COMPATIBLE = (
|
||||
|
|
@ -46,4 +47,7 @@ def create_llm_client(
|
|||
if provider_lower == "azure":
|
||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "bedrock":
|
||||
return BedrockClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
|
|
@ -100,6 +100,16 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
|||
],
|
||||
},
|
||||
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||
"bedrock": {
|
||||
"quick": [
|
||||
("Claude Haiku 4.5 via Bedrock - Fast", "us.anthropic.claude-haiku-4-5-20251001-v1:0"),
|
||||
("Claude Sonnet 4.6 via Bedrock - Balanced", "us.anthropic.claude-sonnet-4-6"),
|
||||
],
|
||||
"deep": [
|
||||
("Claude Sonnet 4.6 via Bedrock - Balanced", "us.anthropic.claude-sonnet-4-6"),
|
||||
("Claude Opus 4.6 via Bedrock - Most capable", "us.anthropic.claude-opus-4-6-v1"),
|
||||
],
|
||||
},
|
||||
"ollama": {
|
||||
"quick": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
|
|
|
|||
Loading…
Reference in New Issue