Merge 40bf4aee99 into fa4d01c23a
This commit is contained in:
commit
c7df42804c
11
cli/main.py
11
cli/main.py
|
|
@ -569,6 +569,7 @@ def get_user_selections():
|
||||||
thinking_level = None
|
thinking_level = None
|
||||||
reasoning_effort = None
|
reasoning_effort = None
|
||||||
anthropic_effort = None
|
anthropic_effort = None
|
||||||
|
bedrock_region = None
|
||||||
|
|
||||||
provider_lower = selected_llm_provider.lower()
|
provider_lower = selected_llm_provider.lower()
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
|
|
@ -595,6 +596,14 @@ def get_user_selections():
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
anthropic_effort = ask_anthropic_effort()
|
anthropic_effort = ask_anthropic_effort()
|
||||||
|
elif provider_lower == "bedrock":
|
||||||
|
console.print(
|
||||||
|
create_question_box(
|
||||||
|
"Step 8: AWS Region",
|
||||||
|
"Select the AWS region for Bedrock inference"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
bedrock_region = ask_bedrock_region()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"ticker": selected_ticker,
|
"ticker": selected_ticker,
|
||||||
|
|
@ -608,6 +617,7 @@ def get_user_selections():
|
||||||
"google_thinking_level": thinking_level,
|
"google_thinking_level": thinking_level,
|
||||||
"openai_reasoning_effort": reasoning_effort,
|
"openai_reasoning_effort": reasoning_effort,
|
||||||
"anthropic_effort": anthropic_effort,
|
"anthropic_effort": anthropic_effort,
|
||||||
|
"bedrock_region": bedrock_region,
|
||||||
"output_language": output_language,
|
"output_language": output_language,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -942,6 +952,7 @@ def run_analysis():
|
||||||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||||
config["anthropic_effort"] = selections.get("anthropic_effort")
|
config["anthropic_effort"] = selections.get("anthropic_effort")
|
||||||
|
config["bedrock_region"] = selections.get("bedrock_region")
|
||||||
config["output_language"] = selections.get("output_language", "English")
|
config["output_language"] = selections.get("output_language", "English")
|
||||||
|
|
||||||
# Create stats callback handler for tracking LLM/tool calls
|
# Create stats callback handler for tracking LLM/tool calls
|
||||||
|
|
|
||||||
20
cli/utils.py
20
cli/utils.py
|
|
@ -240,6 +240,7 @@ def select_llm_provider() -> tuple[str, str | None]:
|
||||||
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||||
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||||
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||||
|
("Amazon Bedrock", "bedrock", None),
|
||||||
("Azure OpenAI", "azure", None),
|
("Azure OpenAI", "azure", None),
|
||||||
("Ollama", "ollama", "http://localhost:11434/v1"),
|
("Ollama", "ollama", "http://localhost:11434/v1"),
|
||||||
]
|
]
|
||||||
|
|
@ -326,6 +327,25 @@ def ask_gemini_thinking_config() -> str | None:
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def ask_bedrock_region() -> str:
|
||||||
|
"""Ask for AWS Bedrock region."""
|
||||||
|
return questionary.select(
|
||||||
|
"Select AWS Region:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice("US East 1 (N. Virginia) - default", "us-east-1"),
|
||||||
|
questionary.Choice("US West 2 (Oregon)", "us-west-2"),
|
||||||
|
questionary.Choice("EU West 1 (Ireland)", "eu-west-1"),
|
||||||
|
questionary.Choice("AP Northeast 1 (Tokyo)", "ap-northeast-1"),
|
||||||
|
questionary.Choice("AP Southeast 1 (Singapore)", "ap-southeast-1"),
|
||||||
|
],
|
||||||
|
style=questionary.Style([
|
||||||
|
("selected", "fg:cyan noinherit"),
|
||||||
|
("highlighted", "fg:cyan noinherit"),
|
||||||
|
("pointer", "fg:cyan noinherit"),
|
||||||
|
]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
def ask_output_language() -> str:
|
def ask_output_language() -> str:
|
||||||
"""Ask for report output language."""
|
"""Ask for report output language."""
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|
|
||||||
|
|
@ -151,6 +151,11 @@ class TradingAgentsGraph:
|
||||||
if effort:
|
if effort:
|
||||||
kwargs["effort"] = effort
|
kwargs["effort"] = effort
|
||||||
|
|
||||||
|
elif provider == "bedrock":
|
||||||
|
region = self.config.get("bedrock_region")
|
||||||
|
if region:
|
||||||
|
kwargs["region_name"] = region
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from botocore.config import Config as BotoConfig
|
||||||
|
from langchain_aws import ChatBedrockConverse
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
|
||||||
|
_BEDROCK_MODELS = [
|
||||||
|
"us.anthropic.claude-sonnet-4-6",
|
||||||
|
"us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
"us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
]
|
||||||
|
|
||||||
|
_PASSTHROUGH_KWARGS = (
|
||||||
|
"region_name",
|
||||||
|
"credentials_profile_name",
|
||||||
|
"max_tokens",
|
||||||
|
"temperature",
|
||||||
|
"callbacks",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedChatBedrockConverse(ChatBedrockConverse):
|
||||||
|
def invoke(self, input, config=None, **kwargs):
|
||||||
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
async def ainvoke(self, input, config=None, **kwargs):
|
||||||
|
return normalize_content(await super().ainvoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockClient(BaseLLMClient):
|
||||||
|
"""Client for Amazon Bedrock models via ChatBedrockConverse."""
|
||||||
|
|
||||||
|
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
self.warn_if_unknown_model()
|
||||||
|
llm_kwargs = {
|
||||||
|
"model_id": self.model,
|
||||||
|
"config": BotoConfig(read_timeout=300, retries={"max_attempts": 3}),
|
||||||
|
}
|
||||||
|
if self.base_url and "openai.com" not in self.base_url:
|
||||||
|
llm_kwargs["endpoint_url"] = self.base_url
|
||||||
|
for key in _PASSTHROUGH_KWARGS:
|
||||||
|
if key in self.kwargs:
|
||||||
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
if "region_name" not in llm_kwargs:
|
||||||
|
llm_kwargs["region_name"] = "us-east-1"
|
||||||
|
return NormalizedChatBedrockConverse(**llm_kwargs)
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
return self.model in _BEDROCK_MODELS
|
||||||
|
|
@ -5,6 +5,7 @@ from .openai_client import OpenAIClient
|
||||||
from .anthropic_client import AnthropicClient
|
from .anthropic_client import AnthropicClient
|
||||||
from .google_client import GoogleClient
|
from .google_client import GoogleClient
|
||||||
from .azure_client import AzureOpenAIClient
|
from .azure_client import AzureOpenAIClient
|
||||||
|
from .bedrock_client import BedrockClient
|
||||||
|
|
||||||
# Providers that use the OpenAI-compatible chat completions API
|
# Providers that use the OpenAI-compatible chat completions API
|
||||||
_OPENAI_COMPATIBLE = (
|
_OPENAI_COMPATIBLE = (
|
||||||
|
|
@ -46,4 +47,7 @@ def create_llm_client(
|
||||||
if provider_lower == "azure":
|
if provider_lower == "azure":
|
||||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "bedrock":
|
||||||
|
return BedrockClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -100,6 +100,16 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||||
|
"bedrock": {
|
||||||
|
"quick": [
|
||||||
|
("Claude Haiku 4.5 via Bedrock - Fast", "us.anthropic.claude-haiku-4-5-20251001-v1:0"),
|
||||||
|
("Claude Sonnet 4.6 via Bedrock - Balanced", "us.anthropic.claude-sonnet-4-6"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Claude Sonnet 4.6 via Bedrock - Balanced", "us.anthropic.claude-sonnet-4-6"),
|
||||||
|
("Claude Opus 4.6 via Bedrock - Most capable", "us.anthropic.claude-opus-4-6-v1"),
|
||||||
|
],
|
||||||
|
},
|
||||||
"ollama": {
|
"ollama": {
|
||||||
"quick": [
|
"quick": [
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue