This commit is contained in:
Bastien 2026-04-15 09:49:51 +02:00 committed by GitHub
commit c7df42804c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 104 additions and 0 deletions

View File

@ -569,6 +569,7 @@ def get_user_selections():
thinking_level = None
reasoning_effort = None
anthropic_effort = None
bedrock_region = None
provider_lower = selected_llm_provider.lower()
if provider_lower == "google":
@ -595,6 +596,14 @@ def get_user_selections():
)
)
anthropic_effort = ask_anthropic_effort()
elif provider_lower == "bedrock":
console.print(
create_question_box(
"Step 8: AWS Region",
"Select the AWS region for Bedrock inference"
)
)
bedrock_region = ask_bedrock_region()
return {
"ticker": selected_ticker,
@ -608,6 +617,7 @@ def get_user_selections():
"google_thinking_level": thinking_level,
"openai_reasoning_effort": reasoning_effort,
"anthropic_effort": anthropic_effort,
"bedrock_region": bedrock_region,
"output_language": output_language,
}
@ -942,6 +952,7 @@ def run_analysis():
config["google_thinking_level"] = selections.get("google_thinking_level")
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
config["anthropic_effort"] = selections.get("anthropic_effort")
config["bedrock_region"] = selections.get("bedrock_region")
config["output_language"] = selections.get("output_language", "English")
# Create stats callback handler for tracking LLM/tool calls

View File

@ -240,6 +240,7 @@ def select_llm_provider() -> tuple[str, str | None]:
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
("Amazon Bedrock", "bedrock", None),
("Azure OpenAI", "azure", None),
("Ollama", "ollama", "http://localhost:11434/v1"),
]
@ -326,6 +327,25 @@ def ask_gemini_thinking_config() -> str | None:
).ask()
def ask_bedrock_region() -> str:
    """Prompt the user to pick the AWS region used for Bedrock inference.

    Returns:
        The selected region identifier (e.g. ``"us-east-1"``).
    """
    # (label shown to the user, region id returned)
    region_options = [
        ("US East 1 (N. Virginia) - default", "us-east-1"),
        ("US West 2 (Oregon)", "us-west-2"),
        ("EU West 1 (Ireland)", "eu-west-1"),
        ("AP Northeast 1 (Tokyo)", "ap-northeast-1"),
        ("AP Southeast 1 (Singapore)", "ap-southeast-1"),
    ]
    cyan = "fg:cyan noinherit"
    prompt = questionary.select(
        "Select AWS Region:",
        choices=[questionary.Choice(label, region) for label, region in region_options],
        style=questionary.Style(
            [("selected", cyan), ("highlighted", cyan), ("pointer", cyan)]
        ),
    )
    return prompt.ask()
def ask_output_language() -> str:
"""Ask for report output language."""
choice = questionary.select(

View File

@ -151,6 +151,11 @@ class TradingAgentsGraph:
if effort:
kwargs["effort"] = effort
elif provider == "bedrock":
region = self.config.get("bedrock_region")
if region:
kwargs["region_name"] = region
return kwargs
def _create_tool_nodes(self) -> Dict[str, ToolNode]:

View File

@ -0,0 +1,54 @@
from typing import Any, Optional
from botocore.config import Config as BotoConfig
from langchain_aws import ChatBedrockConverse
from .base_client import BaseLLMClient, normalize_content
# Bedrock model IDs this client recognizes; validate_model() checks against this.
# NOTE(review): these look like US cross-region inference-profile IDs — confirm
# they match the account's enabled Bedrock models before relying on validation.
_BEDROCK_MODELS = [
    "us.anthropic.claude-sonnet-4-6",
    "us.anthropic.claude-haiku-4-5-20251001-v1:0",
    "us.anthropic.claude-opus-4-6-v1",
    "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
]
# Keys copied verbatim from BedrockClient.kwargs into the
# ChatBedrockConverse constructor kwargs in get_llm().
_PASSTHROUGH_KWARGS = (
    "region_name",
    "credentials_profile_name",
    "max_tokens",
    "temperature",
    "callbacks",
)
class NormalizedChatBedrockConverse(ChatBedrockConverse):
    """ChatBedrockConverse whose invoke/ainvoke results pass through normalize_content."""

    def invoke(self, input, config=None, **kwargs):
        response = super().invoke(input, config, **kwargs)
        return normalize_content(response)

    async def ainvoke(self, input, config=None, **kwargs):
        response = await super().ainvoke(input, config, **kwargs)
        return normalize_content(response)
class BedrockClient(BaseLLMClient):
    """Client for Amazon Bedrock models via ChatBedrockConverse."""

    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
        super().__init__(model, base_url, **kwargs)

    def get_llm(self) -> Any:
        """Build a NormalizedChatBedrockConverse from this client's configuration.

        Returns:
            A chat model instance ready for invoke/ainvoke calls.
        """
        self.warn_if_unknown_model()
        # Generous read timeout plus bounded retries for long generations.
        boto_config = BotoConfig(read_timeout=300, retries={"max_attempts": 3})
        llm_params: dict = {"model_id": self.model, "config": boto_config}
        # Only forward a custom endpoint; base_url values containing
        # "openai.com" are skipped (presumably a leaked OpenAI default —
        # TODO confirm with callers of this client).
        if self.base_url and "openai.com" not in self.base_url:
            llm_params["endpoint_url"] = self.base_url
        # Copy the whitelisted constructor kwargs straight through.
        llm_params.update(
            {key: value for key, value in self.kwargs.items() if key in _PASSTHROUGH_KWARGS}
        )
        # Fall back to us-east-1 when no region was supplied.
        llm_params.setdefault("region_name", "us-east-1")
        return NormalizedChatBedrockConverse(**llm_params)

    def validate_model(self) -> bool:
        """Return True when the model id is one of the known Bedrock models."""
        return self.model in _BEDROCK_MODELS

View File

@ -5,6 +5,7 @@ from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .azure_client import AzureOpenAIClient
from .bedrock_client import BedrockClient
# Providers that use the OpenAI-compatible chat completions API
_OPENAI_COMPATIBLE = (
@ -46,4 +47,7 @@ def create_llm_client(
if provider_lower == "azure":
return AzureOpenAIClient(model, base_url, **kwargs)
if provider_lower == "bedrock":
return BedrockClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@ -100,6 +100,16 @@ MODEL_OPTIONS: ProviderModeOptions = {
],
},
# OpenRouter: fetched dynamically. Azure: any deployed model name.
"bedrock": {
"quick": [
("Claude Haiku 4.5 via Bedrock - Fast", "us.anthropic.claude-haiku-4-5-20251001-v1:0"),
("Claude Sonnet 4.6 via Bedrock - Balanced", "us.anthropic.claude-sonnet-4-6"),
],
"deep": [
("Claude Sonnet 4.6 via Bedrock - Balanced", "us.anthropic.claude-sonnet-4-6"),
("Claude Opus 4.6 via Bedrock - Most capable", "us.anthropic.claude-opus-4-6-v1"),
],
},
"ollama": {
"quick": [
("Qwen3:latest (8B, local)", "qwen3:latest"),