This commit is contained in:
🐈 Yun Wang 2025-08-16 17:14:07 +08:00 committed by GitHub
commit 54e299f039
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 397 additions and 122 deletions

2
.gitignore vendored
View File

@ -7,3 +7,5 @@ eval_results/
eval_data/
*.egg-info/
.env
*.log
results/*

View File

@ -1,5 +1,4 @@
from typing import Optional
import datetime
from datetime import datetime
import typer
from pathlib import Path
from functools import wraps
@ -14,16 +13,23 @@ from rich.text import Text
from rich.live import Live
from rich.table import Table
from collections import deque
import time
from rich.tree import Tree
from rich import box
from rich.align import Align
from rich.rule import Rule
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
from cli.utils import (
select_analysts,
select_research_depth,
select_shallow_thinking_agent,
select_deep_thinking_agent,
select_llm_provider
)
console = Console()
@ -72,11 +78,11 @@ class MessageBuffer:
}
def add_message(self, message_type, content):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
timestamp = datetime.now().strftime("%H:%M:%S")
self.messages.append((timestamp, message_type, content))
def add_tool_call(self, tool_name, args):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
timestamp = datetime.now().strftime("%H:%M:%S")
self.tool_calls.append((timestamp, tool_name, args))
def update_agent_status(self, agent, status):
@ -434,7 +440,7 @@ def get_user_selections():
selected_ticker = get_ticker()
# Step 2: Analysis date
default_date = datetime.datetime.now().strftime("%Y-%m-%d")
default_date = datetime.now().strftime("%Y-%m-%d")
console.print(
create_question_box(
"Step 2: Analysis Date",
@ -501,12 +507,12 @@ def get_analysis_date():
"""Get the analysis date from user input."""
while True:
date_str = typer.prompt(
"", default=datetime.datetime.now().strftime("%Y-%m-%d")
"", default=datetime.now().strftime("%Y-%m-%d")
)
try:
# Validate date format and ensure it's not in the future
analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
if analysis_date.date() > datetime.datetime.now().date():
analysis_date = datetime.strptime(date_str, "%Y-%m-%d")
if analysis_date.date() > datetime.now().date():
console.print("[red]Error: Analysis date cannot be in the future[/red]")
continue
return date_str

View File

@ -1,8 +1,17 @@
import os
import questionary
from typing import List, Optional, Tuple, Dict
import re
from datetime import datetime
from rich.console import Console
from typing import List
from urllib.parse import urlparse
from cli.models import AnalystType
console = Console()
CUSTOM_MODEL_IDENTIFIER = "__CUSTOM_MODEL__"
ANALYST_ORDER = [
("Market Analyst", AnalystType.MARKET),
("Social Media Analyst", AnalystType.SOCIAL),
@ -33,8 +42,6 @@ def get_ticker() -> str:
def get_analysis_date() -> str:
"""Prompt the user to enter a date in YYYY-MM-DD format."""
import re
from datetime import datetime
def validate_date(date_str: str) -> bool:
if not re.match(r"^\d{4}-\d{2}-\d{2}$", date_str):
@ -122,44 +129,195 @@ def select_research_depth() -> int:
return choice
def select_shallow_thinking_agent(provider) -> str:
"""Select shallow thinking llm engine using an interactive selection."""
# Centralized model definitions - single source of truth
# Quick-thinking (shallow) model choices, keyed by provider name.
# Each entry is a (display description, model identifier) pair consumed by
# the questionary selection prompts below.
SHALLOW_AGENT_OPTIONS = {
    "openai": [
        ("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"),
        ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
        ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
        ("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
    ],
    "anthropic": [
        ("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
        ("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
        ("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
        ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
    ],
    "google": [
        ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
        ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
        ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
    ],
    "openrouter": [
        ("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
        ("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
        ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
    ],
    "ollama": [
        ("llama3.1 local", "llama3.1"),
        ("llama3.2 local", "llama3.2"),
    ]
}
# Define shallow thinking llm engine options with their corresponding model names
SHALLOW_AGENT_OPTIONS = {
"openai": [
("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"),
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
],
"anthropic": [
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
],
"openrouter": [
("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
],
"ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"),
]
}
# Deep-thinking (reasoning) model choices, keyed by provider name.
# Each entry is a (display description, model identifier) pair.
DEEP_AGENT_OPTIONS = {
    "openai": [
        ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
        ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
        ("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
        ("o4-mini - Specialized reasoning model (compact)", "o4-mini"),
        ("o3-mini - Advanced reasoning model (lightweight)", "o3-mini"),
        ("o3 - Full advanced reasoning model", "o3"),
        ("o1 - Premier reasoning and problem-solving model", "o1"),
    ],
    "anthropic": [
        ("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
        ("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
        ("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
        ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
        ("Claude Opus 4 - Most powerful Anthropic model", "claude-opus-4-0"),
    ],
    "google": [
        ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
        ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
        ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
        ("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"),
    ],
    "openrouter": [
        # NOTE(review): both entries below resolve to the SAME model id
        # ("deepseek/deepseek-chat-v3-0324:free") — confirm whether the second
        # entry was meant to point at a different model.
        ("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
        ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
    ],
    "ollama": [
        ("llama3.1 local", "llama3.1"),
        ("qwen3", "qwen3"),
    ]
}
def _get_all_models_for_custom_provider(model_type: str) -> list:
    """Build the unified model list offered for a custom provider.

    Flattens every provider's catalogue into one list, labelling each entry
    with its provider name, then appends a "Custom Model" sentinel entry so
    the user can type an arbitrary model name.

    Args:
        model_type: Either 'shallow' or 'deep' to pick the model catalogue.

    Returns:
        List of (description, model_value) tuples.
    """
    catalogue = SHALLOW_AGENT_OPTIONS if model_type == "shallow" else DEEP_AGENT_OPTIONS

    # Flatten the per-provider catalogue, tagging each entry with its source.
    combined = [
        (f"{description} ({source.title()})", model_value)
        for source, models in catalogue.items()
        for description, model_value in models
    ]

    # Sentinel entry handled specially by the selection prompt.
    combined.append(("Custom Model - Enter your own model name", CUSTOM_MODEL_IDENTIFIER))
    return combined
def _select_custom_provider_model(model_type: str, title: str, default_model: str) -> str:
"""Handle model selection for custom provider with unified model list.
Args:
model_type: Either 'shallow' or 'deep'
title: Title for the selection prompt
default_model: Default model name for custom input
Returns:
Selected model name
"""
all_models = _get_all_models_for_custom_provider(model_type)
choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:",
title,
choices=[
questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
for display, value in all_models
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select\n- Your custom endpoint should support the selected model",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if choice is None:
raise ValueError(f"No {model_type} thinking model selected")
# Handle custom model input
if choice == CUSTOM_MODEL_IDENTIFIER:
custom_model = questionary.text(
f"Enter your custom {model_type} thinking model name:",
default=default_model,
instruction="\n- Enter the exact model name as supported by your custom endpoint\n- Press Enter to confirm"
).ask()
if not custom_model:
raise ValueError(f"No custom {model_type} model name entered")
return custom_model
return choice
def _select_thinking_agent(provider: str, model_type: str) -> str:
"""Unified function to select thinking agents for both shallow and deep models.
Args:
provider: The LLM provider name
model_type: Either 'shallow' or 'deep'
Returns:
str: The selected model name
"""
# Configuration for different model types
config = {
"shallow": {
"title": "Select Your [Quick-Thinking LLM Engine]:",
"custom_title": "Select Your [Quick-Thinking LLM Engine] (Custom Provider - All Models Available):",
"default_model": "gpt-4o-mini",
"options": SHALLOW_AGENT_OPTIONS,
"error_message": "No shallow thinking llm engine selected. Exiting..."
},
"deep": {
"title": "Select Your [Deep-Thinking LLM Engine]:",
"custom_title": "Select Your [Deep-Thinking LLM Engine] (Custom Provider - All Models Available):",
"default_model": "o4-mini",
"options": DEEP_AGENT_OPTIONS,
"error_message": "No deep thinking llm engine selected. Exiting..."
}
}
model_config = config[model_type]
# Handle custom provider - use unified model selection
if provider.lower().startswith("custom"):
try:
return _select_custom_provider_model(
model_type=model_type,
title=model_config["custom_title"],
default_model=model_config["default_model"]
)
except ValueError as e:
console.print(f"\n[red]Error: {e}[/red]")
exit(1)
# Use centralized model definitions
choice = questionary.select(
model_config["title"],
choices=[
questionary.Choice(display, value=value)
for display, value in model_config["options"][provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
@ -172,83 +330,104 @@ def select_shallow_thinking_agent(provider) -> str:
).ask()
if choice is None:
console.print(
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
)
console.print(f"\n[red]{model_config['error_message']}[/red]")
exit(1)
return choice
def select_shallow_thinking_agent(provider) -> str:
    """Interactively choose the quick-thinking (shallow) LLM engine.

    Thin wrapper that delegates to the unified selection helper with the
    shallow model catalogue.
    """
    model_type = "shallow"
    return _select_thinking_agent(provider, model_type)
def select_deep_thinking_agent(provider) -> str:
    """Interactively choose the deep-thinking (reasoning) LLM engine.

    Thin wrapper that delegates to the unified selection helper with the
    deep model catalogue.
    """
    model_type = "deep"
    return _select_thinking_agent(provider, model_type)
# Define deep thinking llm engine options with their corresponding model names
DEEP_AGENT_OPTIONS = {
"openai": [
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
("o4-mini - Specialized reasoning model (compact)", "o4-mini"),
("o3-mini - Advanced reasoning model (lightweight)", "o3-mini"),
("o3 - Full advanced reasoning model", "o3"),
("o1 - Premier reasoning and problem-solving model", "o1"),
],
"anthropic": [
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
("Claude Opus 4 - Most powerful Anthropic model", " claude-opus-4-0"),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"),
],
"openrouter": [
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
],
"ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"),
]
}
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
def validate_custom_url(url: str) -> str:
    """Validate that a custom URL is properly formatted and has a valid hostname.

    Args:
        url: The URL to validate

    Returns:
        str: The validated URL, unchanged. An empty/falsy url returns "".

    Raises:
        ValueError: If the URL is invalid or malformed
    """
    if not url:
        return ""

    # Basic URL format validation: http(s) scheme, a hostname that is a
    # dotted domain, "localhost", or an IPv4 address, an optional port, and
    # an optional path/query component.
    url_pattern = re.compile(
        r'^https?://'  # http:// or https://
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'  # domain...
        r'localhost|'  # localhost...
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
        r'(?::\d+)?'  # optional port
        r'(?:/?|[/?]\S+)$', re.IGNORECASE)

    if not url_pattern.match(url):
        raise ValueError(f"Invalid CUSTOM_BASE_URL format: {url}. Please provide a valid URL (e.g., https://api.example.com/v1)")

    # Additional validation using urlparse to make sure a hostname survives
    # structured parsing (the regex alone cannot guarantee that).
    try:
        parsed = urlparse(url)
        if not parsed.netloc:
            raise ValueError(f"Invalid CUSTOM_BASE_URL: {url}. No hostname found")
        return url
    except ValueError:
        # Re-raise ValueError as-is
        raise
    except Exception as e:
        # Chain the original parsing error for easier debugging.
        raise ValueError(f"Invalid CUSTOM_BASE_URL: {url}. URL parsing error: {e}") from e
def get_custom_provider_info() -> tuple[str, str] | None:
"""Get custom provider info if both URL and API key are provided.
Returns:
tuple[str, str] | None: (display_name, url) if valid custom provider configured, None otherwise
Raises:
SystemExit: If custom URL is provided but invalid (exits with error message)
"""
custom_url = os.getenv("CUSTOM_BASE_URL")
custom_api_key = os.getenv("CUSTOM_API_KEY")
if custom_url and custom_api_key:
try:
validated_url = validate_custom_url(custom_url)
parsed = urlparse(validated_url)
hostname = parsed.netloc
return f"Custom ({hostname})", validated_url
except ValueError as e:
console.print(f"[red]Error: {e}[/red]")
exit(1)
return None
return choice
def select_llm_provider() -> tuple[str, str]:
"""Select the OpenAI api url using interactive selection."""
# Define OpenAI api options with their corresponding endpoints
"""Select the LLM provider with support for a custom OpenAI-compatible endpoint."""
# Define default providers
BASE_URLS = [
("OpenAI", "https://api.openai.com/v1"),
("Anthropic", "https://api.anthropic.com/"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
("Ollama", "http://localhost:11434/v1"),
]
# Add custom provider at the beginning if available
custom_info = get_custom_provider_info()
if custom_info:
BASE_URLS.insert(0, custom_info)
choice = questionary.select(
"Select your LLM Provider:",
@ -267,10 +446,10 @@ def select_llm_provider() -> tuple[str, str]:
).ask()
if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
exit(1)
display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
return display_name, url

20
example.env Normal file
View File

@ -0,0 +1,20 @@
# Copy this to your .env file and modify the URLs and API keys as needed
# Custom OpenAI-Compatible Provider (optional)
# If provided, a "Custom" option will appear first in the provider list
# The custom endpoint must be OpenAI-compatible (REST API, not gRPC)
# CUSTOM_BASE_URL=https://www.example.com/v1
# CUSTOM_API_KEY=sk-your-custom-api-key-here
# Standard Provider API Keys, please replace with your own keys to use the corresponding provider
OPENAI_API_KEY=sk-your-openai-api-key-here
ANTHROPIC_API_KEY=sk-ant-your-anthropic-api-key-here
GOOGLE_API_KEY=your-google-api-key-here
OPENROUTER_API_KEY=sk-or-your-openrouter-api-key-here
# OLLAMA_API_KEY is usually not needed for local Ollama instances
# Other Configuration
FINNHUB_API_KEY=your-finnhub-api-key-here
# Optional, uncomment to modify
# TRADINGAGENTS_RESULTS_DIR=./results

View File

@ -24,3 +24,4 @@ rich
questionary
langchain_anthropic
langchain-google-genai
python-dotenv

View File

@ -1,6 +1,6 @@
import chromadb
from chromadb.config import Settings
from openai import OpenAI
from tradingagents.utils.provider_utils import get_openai_client
class FinancialSituationMemory:
@ -9,7 +9,8 @@ class FinancialSituationMemory:
self.embedding = "nomic-embed-text"
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.client = get_openai_client(config)
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)

View File

@ -12,9 +12,9 @@ import os
import pandas as pd
from tqdm import tqdm
import yfinance as yf
from openai import OpenAI
from .config import get_config, set_config, DATA_DIR
from tradingagents.utils.provider_utils import get_openai_client
def get_finnhub_news(
ticker: Annotated[
@ -704,7 +704,7 @@ def get_YFin_data(
def get_stock_news_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
client = get_openai_client(config)
response = client.responses.create(
model=config["quick_think_llm"],
@ -739,7 +739,7 @@ def get_stock_news_openai(ticker, curr_date):
def get_global_news_openai(curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
client = get_openai_client(config)
response = client.responses.create(
model=config["quick_think_llm"],
@ -774,7 +774,7 @@ def get_global_news_openai(curr_date):
def get_fundamentals_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
client = get_openai_client(config)
response = client.responses.create(
model=config["quick_think_llm"],

View File

@ -12,7 +12,7 @@ DEFAULT_CONFIG = {
"llm_provider": "openai",
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1",
"backend_url": "https://api.openai.com/v1", # Will be updated based on selected provider
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,

View File

@ -13,6 +13,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.utils.provider_utils import get_api_key_for_provider
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
@ -58,13 +59,16 @@ class TradingAgentsGraph:
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":
provider = self.config["llm_provider"].lower()
if provider in ("openai", "ollama", "openrouter") or provider.startswith("custom"):
api_key = get_api_key_for_provider(self.config)
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"], api_key=api_key)
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"], api_key=api_key)
elif provider == "anthropic":
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "google":
elif provider == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
else:

View File

@ -0,0 +1,62 @@
"""
Utility functions for LLM provider configuration and API key management.
"""
import os
import sys
from openai import OpenAI
def get_api_key_for_provider(config):
    """Resolve the API key for the configured LLM provider.

    Custom providers (llm_provider starting with "custom") always read
    CUSTOM_API_KEY. Known providers map to their conventional environment
    variable; anything unrecognised falls back to OPENAI_API_KEY.

    Args:
        config (dict): Configuration dictionary containing llm_provider

    Returns:
        str: The API key for the provider, or None if not found
    """
    provider = config.get("llm_provider", "openai").lower()

    # Custom OpenAI-compatible endpoints are handled before the mapping.
    if provider.startswith("custom"):
        key = os.getenv("CUSTOM_API_KEY")
        if not key:
            print("Warning: CUSTOM_API_KEY not found in environment variables", file=sys.stderr)
        return key

    env_var = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "google": "GOOGLE_API_KEY",
        "openrouter": "OPENROUTER_API_KEY",
        "ollama": "OLLAMA_API_KEY",
    }.get(provider, "OPENAI_API_KEY")

    key = os.getenv(env_var)
    # Local Ollama instances typically run unauthenticated, so a missing key
    # is only worth warning about for the other providers.
    if not key and provider != "ollama":
        print(f"Warning: {env_var} not found in environment variables", file=sys.stderr)
    return key
def get_openai_client(config):
    """Create an OpenAI client configured for the active provider.

    Centralizes client construction for every backend with an
    OpenAI-compatible interface (OpenAI, OpenRouter, Ollama, and custom
    providers), resolving the API key via get_api_key_for_provider.

    Args:
        config (dict): Configuration dictionary containing llm_provider and backend_url

    Returns:
        OpenAI: Configured OpenAI client instance
    """
    url = config.get("backend_url", "https://api.openai.com/v1")
    return OpenAI(base_url=url, api_key=get_api_key_for_provider(config))