fix: workflow flaw with openrouter

This commit is contained in:
mogita 2025-08-16 15:08:08 +08:00
parent 10f5b9bf12
commit 84cf5fbfea
No known key found for this signature in database
GPG Key ID: A0AA1B9C57A48ECF
3 changed files with 36 additions and 35 deletions

View File

@ -1,7 +1,6 @@
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
from openai import OpenAI from tradingagents.utils.provider_utils import get_openai_client
import os
class FinancialSituationMemory: class FinancialSituationMemory:
@ -11,14 +10,7 @@ class FinancialSituationMemory:
else: else:
self.embedding = "text-embedding-3-small" self.embedding = "text-embedding-3-small"
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY self.client = get_openai_client(config)
provider = config.get("llm_provider", "openai").lower()
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
else:
api_key = os.getenv("OPENAI_API_KEY")
self.client = OpenAI(base_url=config["backend_url"], api_key=api_key)
self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name) self.situation_collection = self.chroma_client.create_collection(name=name)

View File

@ -12,7 +12,6 @@ import os
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
import yfinance as yf import yfinance as yf
from openai import OpenAI
from .config import get_config, set_config, DATA_DIR from .config import get_config, set_config, DATA_DIR
@ -704,14 +703,8 @@ def get_YFin_data(
def get_stock_news_openai(ticker, curr_date): def get_stock_news_openai(ticker, curr_date):
config = get_config() config = get_config()
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY from tradingagents.utils.provider_utils import get_openai_client
provider = config.get("llm_provider", "openai").lower() client = get_openai_client(config)
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
else:
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
response = client.responses.create( response = client.responses.create(
model=config["quick_think_llm"], model=config["quick_think_llm"],
@ -746,14 +739,8 @@ def get_stock_news_openai(ticker, curr_date):
def get_global_news_openai(curr_date): def get_global_news_openai(curr_date):
config = get_config() config = get_config()
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY from tradingagents.utils.provider_utils import get_openai_client
provider = config.get("llm_provider", "openai").lower() client = get_openai_client(config)
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
else:
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
response = client.responses.create( response = client.responses.create(
model=config["quick_think_llm"], model=config["quick_think_llm"],
@ -788,14 +775,8 @@ def get_global_news_openai(curr_date):
def get_fundamentals_openai(ticker, curr_date): def get_fundamentals_openai(ticker, curr_date):
config = get_config() config = get_config()
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY from tradingagents.utils.provider_utils import get_openai_client
provider = config.get("llm_provider", "openai").lower() client = get_openai_client(config)
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
else:
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
response = client.responses.create( response = client.responses.create(
model=config["quick_think_llm"], model=config["quick_think_llm"],

View File

@ -16,6 +16,13 @@ def get_api_key_for_provider(config):
""" """
provider = config.get("llm_provider", "openai").lower() provider = config.get("llm_provider", "openai").lower()
# Handle custom provider first
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
if not api_key:
print("Warning: CUSTOM_API_KEY not found in environment variables")
return api_key
# Map providers to their environment variables # Map providers to their environment variables
api_key_mapping = { api_key_mapping = {
"openai": "OPENAI_API_KEY", "openai": "OPENAI_API_KEY",
@ -32,3 +39,24 @@ def get_api_key_for_provider(config):
print(f"Warning: {env_var} not found in environment variables") print(f"Warning: {env_var} not found in environment variables")
return api_key return api_key
def get_openai_client(config):
"""Get a properly configured OpenAI client based on the provider configuration.
This function centralizes OpenAI client creation with correct API key resolution
for all providers that use OpenAI-compatible interfaces (OpenAI, OpenRouter,
Ollama, and custom providers).
Args:
config (dict): Configuration dictionary containing llm_provider and backend_url
Returns:
OpenAI: Configured OpenAI client instance
"""
from openai import OpenAI
api_key = get_api_key_for_provider(config)
backend_url = config.get("backend_url", "https://api.openai.com/v1")
return OpenAI(base_url=backend_url, api_key=api_key)