Compare commits
5 Commits
0e242d33d2
...
28f2384990
| Author | SHA1 | Date |
|---|---|---|
|
|
28f2384990 | |
|
|
9ba1858948 | |
|
|
fa4d01c23a | |
|
|
b0f6058299 | |
|
|
59d6b2152d |
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Azure OpenAI
|
||||||
|
AZURE_OPENAI_API_KEY=
|
||||||
|
AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
|
||||||
|
AZURE_OPENAI_DEPLOYMENT_NAME=
|
||||||
|
# OPENAI_API_VERSION=2024-10-21 # optional; only needed when using the non-v1 API
|
||||||
|
|
@ -3,4 +3,7 @@ OPENAI_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
XAI_API_KEY=
|
XAI_API_KEY=
|
||||||
|
DEEPSEEK_API_KEY=
|
||||||
|
DASHSCOPE_API_KEY=
|
||||||
|
ZHIPU_API_KEY=
|
||||||
OPENROUTER_API_KEY=
|
OPENROUTER_API_KEY=
|
||||||
|
|
|
||||||
|
|
@ -140,10 +140,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT)
|
||||||
export GOOGLE_API_KEY=... # Google (Gemini)
|
export GOOGLE_API_KEY=... # Google (Gemini)
|
||||||
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
||||||
export XAI_API_KEY=... # xAI (Grok)
|
export XAI_API_KEY=... # xAI (Grok)
|
||||||
|
export DEEPSEEK_API_KEY=... # DeepSeek
|
||||||
|
export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
|
||||||
|
export ZHIPU_API_KEY=... # GLM (Zhipu)
|
||||||
export OPENROUTER_API_KEY=... # OpenRouter
|
export OPENROUTER_API_KEY=... # OpenRouter
|
||||||
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
||||||
|
|
||||||
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
||||||
|
|
||||||
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,41 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call_signature(tool_call):
    """Return a hashable ``(name, serialized_args)`` pair for a tool call.

    Args:
        tool_call: Either a dict with ``"name"`` and ``"args"`` entries or an
            object exposing ``name`` and ``args`` attributes.

    Returns:
        A 2-tuple of the tool name and a string rendering of its arguments.
        Missing fields are treated as ``None`` rather than raising, because
        this helper only feeds de-duplication and must not abort ingestion.
    """
    if isinstance(tool_call, dict):
        name = tool_call.get("name")
        args = tool_call.get("args")
    else:
        name = getattr(tool_call, "name", None)
        args = getattr(tool_call, "args", None)

    try:
        # sort_keys makes the signature independent of argument order;
        # default=str stringifies values the JSON encoder cannot handle.
        serialized = json.dumps(args, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # sort_keys cannot order mixed-type keys (TypeError) and dumps rejects
        # circular structures (ValueError); repr still yields a stable string.
        serialized = repr(args)

    return (name, serialized)
|
||||||
|
|
||||||
|
|
||||||
|
def _message_fingerprint(message, msg_type, content):
    """Build a hashable key describing a message that carries no id.

    The key combines the message's class name, the classified message type,
    the content (stripped when it is a string, ``str()``-converted otherwise)
    and the signatures of any tool calls attached to the message.
    """
    # A missing or falsy ``tool_calls`` attribute is treated as "no calls".
    raw_calls = getattr(message, "tool_calls", []) or []
    call_signatures = tuple([_tool_call_signature(call) for call in raw_calls])

    if isinstance(content, str):
        normalized_content = content.strip()
    else:
        normalized_content = str(content)

    return (
        message.__class__.__name__,
        msg_type,
        normalized_content,
        call_signatures,
    )
|
||||||
|
|
||||||
|
|
||||||
def ingest_chunk_messages(message_buffer, chunk, classify_message_type) -> None:
|
def ingest_chunk_messages(message_buffer, chunk, classify_message_type) -> None:
|
||||||
"""Ingest all newly seen messages from a graph stream chunk."""
|
"""Ingest all newly seen messages from a graph stream chunk."""
|
||||||
for message in chunk.get("messages", []):
|
for message in chunk.get("messages", []):
|
||||||
|
msg_type, content = classify_message_type(message)
|
||||||
msg_id = getattr(message, "id", None)
|
msg_id = getattr(message, "id", None)
|
||||||
if msg_id is not None:
|
if msg_id is not None:
|
||||||
if msg_id in message_buffer._processed_message_ids:
|
if msg_id in message_buffer._processed_message_ids:
|
||||||
continue
|
continue
|
||||||
message_buffer._processed_message_ids.add(msg_id)
|
message_buffer._processed_message_ids.add(msg_id)
|
||||||
|
else:
|
||||||
|
fingerprint = _message_fingerprint(message, msg_type, content)
|
||||||
|
if fingerprint in message_buffer._processed_message_fingerprints:
|
||||||
|
continue
|
||||||
|
message_buffer._processed_message_fingerprints.add(fingerprint)
|
||||||
|
|
||||||
msg_type, content = classify_message_type(message)
|
|
||||||
if content and content.strip():
|
if content and content.strip():
|
||||||
message_buffer.add_message(msg_type, content)
|
message_buffer.add_message(msg_type, content)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,9 @@ from functools import wraps
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Load environment variables from .env file
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
load_dotenv(".env.enterprise", override=False)
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.spinner import Spinner
|
from rich.spinner import Spinner
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
|
|
@ -81,6 +82,7 @@ class MessageBuffer:
|
||||||
self.report_sections = {}
|
self.report_sections = {}
|
||||||
self.selected_analysts = []
|
self.selected_analysts = []
|
||||||
self._processed_message_ids = set()
|
self._processed_message_ids = set()
|
||||||
|
self._processed_message_fingerprints = set()
|
||||||
|
|
||||||
def init_for_analysis(self, selected_analysts):
|
def init_for_analysis(self, selected_analysts):
|
||||||
"""Initialize agent status and report sections based on selected analysts.
|
"""Initialize agent status and report sections based on selected analysts.
|
||||||
|
|
@ -116,6 +118,7 @@ class MessageBuffer:
|
||||||
self.messages.clear()
|
self.messages.clear()
|
||||||
self.tool_calls.clear()
|
self.tool_calls.clear()
|
||||||
self._processed_message_ids.clear()
|
self._processed_message_ids.clear()
|
||||||
|
self._processed_message_fingerprints.clear()
|
||||||
|
|
||||||
def get_completed_reports_count(self):
|
def get_completed_reports_count(self):
|
||||||
"""Count reports that are finalized (their finalizing agent is completed).
|
"""Count reports that are finalized (their finalizing agent is completed).
|
||||||
|
|
|
||||||
92
cli/utils.py
92
cli/utils.py
|
|
@ -174,17 +174,30 @@ def select_openrouter_model() -> str:
|
||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
||||||
def select_shallow_thinking_agent(provider) -> str:
|
def _prompt_custom_model_id() -> str:
    """Prompt user to type a custom model ID.

    Returns:
        The entered model ID with surrounding whitespace removed.

    Exits the process with status 1 when the prompt yields no answer, the
    same way the other selection prompts in this module handle it.
    """
    answer = questionary.text(
        "Enter model ID:",
        validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
    ).ask()

    # ask() yields None when the prompt is cancelled (e.g. Ctrl-C); calling
    # .strip() on that would raise AttributeError instead of exiting cleanly.
    if answer is None:
        console.print("\n[red]No model ID entered. Exiting...[/red]")
        exit(1)

    return answer.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _select_model(provider: str, mode: str) -> str:
|
||||||
|
"""Select a model for the given provider and mode (quick/deep)."""
|
||||||
if provider.lower() == "openrouter":
|
if provider.lower() == "openrouter":
|
||||||
return select_openrouter_model()
|
return select_openrouter_model()
|
||||||
|
|
||||||
|
if provider.lower() == "azure":
|
||||||
|
return questionary.text(
|
||||||
|
f"Enter Azure deployment name ({mode}-thinking):",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
|
||||||
|
).ask().strip()
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
"Select Your [Quick-Thinking LLM Engine]:",
|
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
|
||||||
choices=[
|
choices=[
|
||||||
questionary.Choice(display, value=value)
|
questionary.Choice(display, value=value)
|
||||||
for display, value in get_model_options(provider, "quick")
|
for display, value in get_model_options(provider, mode)
|
||||||
],
|
],
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
style=questionary.Style(
|
style=questionary.Style(
|
||||||
|
|
@ -197,58 +210,45 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print(
|
console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
|
||||||
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
|
|
||||||
)
|
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
if choice == "custom":
|
||||||
|
return _prompt_custom_model_id()
|
||||||
|
|
||||||
return choice
|
return choice
|
||||||
|
|
||||||
|
|
||||||
|
def select_shallow_thinking_agent(provider) -> str:
    """Interactively pick the quick-thinking LLM engine for ``provider``.

    Delegates to ``_select_model`` with the ``"quick"`` mode and returns
    whatever model identifier it produces.
    """
    selected_model = _select_model(provider, "quick")
    return selected_model
|
||||||
|
|
||||||
|
|
||||||
def select_deep_thinking_agent(provider) -> str:
|
def select_deep_thinking_agent(provider) -> str:
|
||||||
"""Select deep thinking llm engine using an interactive selection."""
|
"""Select deep thinking llm engine using an interactive selection."""
|
||||||
|
return _select_model(provider, "deep")
|
||||||
if provider.lower() == "openrouter":
|
|
||||||
return select_openrouter_model()
|
|
||||||
|
|
||||||
choice = questionary.select(
|
|
||||||
"Select Your [Deep-Thinking LLM Engine]:",
|
|
||||||
choices=[
|
|
||||||
questionary.Choice(display, value=value)
|
|
||||||
for display, value in get_model_options(provider, "deep")
|
|
||||||
],
|
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
|
||||||
style=questionary.Style(
|
|
||||||
[
|
|
||||||
("selected", "fg:magenta noinherit"),
|
|
||||||
("highlighted", "fg:magenta noinherit"),
|
|
||||||
("pointer", "fg:magenta noinherit"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
).ask()
|
|
||||||
|
|
||||||
if choice is None:
|
|
||||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
return choice
|
|
||||||
|
|
||||||
def select_llm_provider() -> tuple[str, str | None]:
|
def select_llm_provider() -> tuple[str, str | None]:
|
||||||
"""Select the LLM provider and its API endpoint."""
|
"""Select the LLM provider and its API endpoint."""
|
||||||
BASE_URLS = [
|
# (display_name, provider_key, base_url)
|
||||||
("OpenAI", "https://api.openai.com/v1"),
|
PROVIDERS = [
|
||||||
("Google", None), # google-genai SDK manages its own endpoint
|
("OpenAI", "openai", "https://api.openai.com/v1"),
|
||||||
("Anthropic", "https://api.anthropic.com/"),
|
("Google", "google", None),
|
||||||
("xAI", "https://api.x.ai/v1"),
|
("Anthropic", "anthropic", "https://api.anthropic.com/"),
|
||||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
("xAI", "xai", "https://api.x.ai/v1"),
|
||||||
("Ollama", "http://localhost:11434/v1"),
|
("DeepSeek", "deepseek", "https://api.deepseek.com"),
|
||||||
|
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||||
|
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||||
|
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||||
|
("Azure OpenAI", "azure", None),
|
||||||
|
("Ollama", "ollama", "http://localhost:11434/v1"),
|
||||||
]
|
]
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
"Select your LLM Provider:",
|
"Select your LLM Provider:",
|
||||||
choices=[
|
choices=[
|
||||||
questionary.Choice(display, value=(display, value))
|
questionary.Choice(display, value=(provider_key, url))
|
||||||
for display, value in BASE_URLS
|
for display, provider_key, url in PROVIDERS
|
||||||
],
|
],
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
style=questionary.Style(
|
style=questionary.Style(
|
||||||
|
|
@ -261,13 +261,11 @@ def select_llm_provider() -> tuple[str, str | None]:
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
display_name, url = choice
|
|
||||||
print(f"You selected: {display_name}\tURL: {url}")
|
|
||||||
|
|
||||||
return display_name, url
|
provider, url = choice
|
||||||
|
return provider, url
|
||||||
|
|
||||||
|
|
||||||
def ask_openai_reasoning_effort() -> str:
|
def ask_openai_reasoning_effort() -> str:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ services:
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
volumes:
|
volumes:
|
||||||
- ./results:/home/appuser/app/results
|
- tradingagents_data:/home/appuser/.tradingagents
|
||||||
tty: true
|
tty: true
|
||||||
stdin_open: true
|
stdin_open: true
|
||||||
|
|
||||||
|
|
@ -22,7 +22,7 @@ services:
|
||||||
environment:
|
environment:
|
||||||
- LLM_PROVIDER=ollama
|
- LLM_PROVIDER=ollama
|
||||||
volumes:
|
volumes:
|
||||||
- ./results:/home/appuser/app/results
|
- tradingagents_data:/home/appuser/.tradingagents
|
||||||
depends_on:
|
depends_on:
|
||||||
- ollama
|
- ollama
|
||||||
tty: true
|
tty: true
|
||||||
|
|
@ -31,4 +31,5 @@ services:
|
||||||
- ollama
|
- ollama
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
|
tradingagents_data:
|
||||||
ollama_data:
|
ollama_data:
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ class FakeMessage:
|
||||||
class FakeMessageBuffer:
|
class FakeMessageBuffer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._processed_message_ids = set()
|
self._processed_message_ids = set()
|
||||||
|
self._processed_message_fingerprints = set()
|
||||||
self.messages = []
|
self.messages = []
|
||||||
self.tool_calls = []
|
self.tool_calls = []
|
||||||
|
|
||||||
|
|
@ -62,3 +63,29 @@ def test_ingest_chunk_messages_skips_duplicate_message_ids():
|
||||||
|
|
||||||
assert len(message_buffer.messages) == 1
|
assert len(message_buffer.messages) == 1
|
||||||
assert len(message_buffer.tool_calls) == 1
|
assert len(message_buffer.tool_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_ingest_chunk_messages_skips_duplicate_messages_without_ids():
|
||||||
|
message_buffer = FakeMessageBuffer()
|
||||||
|
chunk = {"messages": [FakeMessage(None, "same", [{"name": "tool_a", "args": {"x": 1}}])]}
|
||||||
|
|
||||||
|
ingest_chunk_messages(message_buffer, chunk, fake_classifier)
|
||||||
|
ingest_chunk_messages(message_buffer, chunk, fake_classifier)
|
||||||
|
|
||||||
|
assert len(message_buffer.messages) == 1
|
||||||
|
assert len(message_buffer.tool_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_ingest_chunk_messages_keeps_distinct_messages_without_ids():
|
||||||
|
message_buffer = FakeMessageBuffer()
|
||||||
|
chunk = {
|
||||||
|
"messages": [
|
||||||
|
FakeMessage(None, "first", [{"name": "tool_a", "args": {"x": 1}}]),
|
||||||
|
FakeMessage(None, "second", [{"name": "tool_b", "args": {"y": 2}}]),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
ingest_chunk_messages(message_buffer, chunk, fake_classifier)
|
||||||
|
|
||||||
|
assert [content for _, content in message_buffer.messages] == ["first", "second"]
|
||||||
|
assert [name for name, _ in message_buffer.tool_calls] == ["tool_a", "tool_b"]
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ class FinancialSituationMemory:
|
||||||
|
|
||||||
# Build results
|
# Build results
|
||||||
results = []
|
results = []
|
||||||
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
|
max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0
|
||||||
|
|
||||||
for idx in top_indices:
|
for idx in top_indices:
|
||||||
# Normalize score to 0-1 range for consistency
|
# Normalize score to 0-1 range for consistency
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
|
||||||
|
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
||||||
"data_cache_dir": os.path.join(
|
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
|
||||||
"dataflows/data_cache",
|
|
||||||
),
|
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"llm_provider": "openai",
|
"llm_provider": "openai",
|
||||||
"deep_think_llm": "gpt-5.4",
|
"deep_think_llm": "gpt-5.4",
|
||||||
|
|
|
||||||
|
|
@ -66,10 +66,8 @@ class TradingAgentsGraph:
|
||||||
set_config(self.config)
|
set_config(self.config)
|
||||||
|
|
||||||
# Create necessary directories
|
# Create necessary directories
|
||||||
os.makedirs(
|
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
|
||||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
os.makedirs(self.config["results_dir"], exist_ok=True)
|
||||||
exist_ok=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize LLMs with provider-specific thinking configuration
|
# Initialize LLMs with provider-specific thinking configuration
|
||||||
llm_kwargs = self._get_provider_kwargs()
|
llm_kwargs = self._get_provider_kwargs()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
from .validators import validate_model
|
||||||
|
|
||||||
|
# Keyword arguments that are copied unchanged from the client's kwargs into
# the AzureChatOpenAI constructor when the caller supplied them.
_PASSTHROUGH_KWARGS = (
    "timeout",
    "max_retries",
    "api_key",
    "reasoning_effort",
    "callbacks",
    "http_client",
    "http_async_client",
)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
    """AzureChatOpenAI whose ``invoke`` result goes through ``normalize_content``."""

    def invoke(self, input, config=None, **kwargs):
        """Invoke the parent model and return its normalized response."""
        response = super().invoke(input, config, **kwargs)
        return normalize_content(response)
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOpenAIClient(BaseLLMClient):
    """Client for Azure OpenAI deployments.

    Environment variables:
        AZURE_OPENAI_API_KEY: API key
        AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/)
        AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name; when set to a non-blank
            value it takes precedence over ``model``
        OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)

    Only AZURE_OPENAI_DEPLOYMENT_NAME is read by this class. The remaining
    variables are presumably picked up by AzureChatOpenAI itself -- confirm
    against the langchain_openai documentation.
    """

    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
        """Store the model name, optional base URL and extra keyword arguments.

        ``base_url`` is forwarded to the base class but is not used by
        ``get_llm``.
        """
        super().__init__(model, base_url, **kwargs)

    def get_llm(self) -> Any:
        """Return configured AzureChatOpenAI instance."""
        self.warn_if_unknown_model()

        # A blank ``AZURE_OPENAI_DEPLOYMENT_NAME=`` line in a dotenv file is
        # loaded as an empty string, which ``os.environ.get(name, default)``
        # would return as-is. ``or`` makes blank values fall back to the model.
        env_deployment = (os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME") or "").strip()
        llm_kwargs = {
            "model": self.model,
            "azure_deployment": env_deployment or self.model,
        }

        # Forward only the recognised options the caller actually supplied.
        for key in _PASSTHROUGH_KWARGS:
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        return NormalizedAzureChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Azure accepts any deployed model name."""
        return True
|
||||||
|
|
@ -4,6 +4,12 @@ from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
from .openai_client import OpenAIClient
|
||||||
from .anthropic_client import AnthropicClient
|
from .anthropic_client import AnthropicClient
|
||||||
from .google_client import GoogleClient
|
from .google_client import GoogleClient
|
||||||
|
from .azure_client import AzureOpenAIClient
|
||||||
|
|
||||||
|
# Providers that use the OpenAI-compatible chat completions API
|
||||||
|
_OPENAI_COMPATIBLE = (
|
||||||
|
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_llm_client(
|
def create_llm_client(
|
||||||
|
|
@ -15,16 +21,10 @@ def create_llm_client(
|
||||||
"""Create an LLM client for the specified provider.
|
"""Create an LLM client for the specified provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
provider: LLM provider name
|
||||||
model: Model name/identifier
|
model: Model name/identifier
|
||||||
base_url: Optional base URL for API endpoint
|
base_url: Optional base URL for API endpoint
|
||||||
**kwargs: Additional provider-specific arguments
|
**kwargs: Additional provider-specific arguments
|
||||||
- http_client: Custom httpx.Client for SSL proxy or certificate customization
|
|
||||||
- http_async_client: Custom httpx.AsyncClient for async operations
|
|
||||||
- timeout: Request timeout in seconds
|
|
||||||
- max_retries: Maximum retry attempts
|
|
||||||
- api_key: API key for the provider
|
|
||||||
- callbacks: LangChain callbacks
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configured BaseLLMClient instance
|
Configured BaseLLMClient instance
|
||||||
|
|
@ -34,16 +34,16 @@ def create_llm_client(
|
||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
if provider_lower in _OPENAI_COMPATIBLE:
|
||||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "xai":
|
|
||||||
return OpenAIClient(model, base_url, provider="xai", **kwargs)
|
|
||||||
|
|
||||||
if provider_lower == "anthropic":
|
if provider_lower == "anthropic":
|
||||||
return AnthropicClient(model, base_url, **kwargs)
|
return AnthropicClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
return GoogleClient(model, base_url, **kwargs)
|
return GoogleClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "azure":
|
||||||
|
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -63,8 +63,43 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
||||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
# OpenRouter models are fetched dynamically at CLI runtime.
|
"deepseek": {
|
||||||
# No static entries needed; any model ID is accepted by the validator.
|
"quick": [
|
||||||
|
("DeepSeek V3.2", "deepseek-chat"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
|
||||||
|
("DeepSeek V3.2", "deepseek-chat"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"qwen": {
|
||||||
|
"quick": [
|
||||||
|
("Qwen 3.5 Flash", "qwen3.5-flash"),
|
||||||
|
("Qwen Plus", "qwen-plus"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Qwen 3.6 Plus", "qwen3.6-plus"),
|
||||||
|
("Qwen 3.5 Plus", "qwen3.5-plus"),
|
||||||
|
("Qwen 3 Max", "qwen3-max"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"glm": {
|
||||||
|
"quick": [
|
||||||
|
("GLM-4.7", "glm-4.7"),
|
||||||
|
("GLM-5", "glm-5"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("GLM-5.1", "glm-5.1"),
|
||||||
|
("GLM-5", "glm-5"),
|
||||||
|
("Custom model ID", "custom"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||||
"ollama": {
|
"ollama": {
|
||||||
"quick": [
|
"quick": [
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,9 @@ _PASSTHROUGH_KWARGS = (
|
||||||
# Provider base URLs and API key env vars
|
# Provider base URLs and API key env vars
|
||||||
_PROVIDER_CONFIG = {
|
_PROVIDER_CONFIG = {
|
||||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||||
|
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
|
||||||
|
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
|
||||||
|
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
|
||||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||||
"ollama": ("http://localhost:11434/v1", None),
|
"ollama": ("http://localhost:11434/v1", None),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue