fix: workflow flaw with openrouter

This commit is contained in:
mogita 2025-08-16 15:08:08 +08:00
parent 10f5b9bf12
commit 84cf5fbfea
No known key found for this signature in database
GPG Key ID: A0AA1B9C57A48ECF
3 changed files with 36 additions and 35 deletions

View File

@ -1,7 +1,6 @@
import chromadb
from chromadb.config import Settings
from openai import OpenAI
import os
from tradingagents.utils.provider_utils import get_openai_client
class FinancialSituationMemory:
@ -11,14 +10,7 @@ class FinancialSituationMemory:
else:
self.embedding = "text-embedding-3-small"
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
provider = config.get("llm_provider", "openai").lower()
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
else:
api_key = os.getenv("OPENAI_API_KEY")
self.client = OpenAI(base_url=config["backend_url"], api_key=api_key)
self.client = get_openai_client(config)
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)

View File

@ -12,7 +12,6 @@ import os
import pandas as pd
from tqdm import tqdm
import yfinance as yf
from openai import OpenAI
from .config import get_config, set_config, DATA_DIR
@ -704,14 +703,8 @@ def get_YFin_data(
def get_stock_news_openai(ticker, curr_date):
config = get_config()
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
provider = config.get("llm_provider", "openai").lower()
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
else:
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
from tradingagents.utils.provider_utils import get_openai_client
client = get_openai_client(config)
response = client.responses.create(
model=config["quick_think_llm"],
@ -746,14 +739,8 @@ def get_stock_news_openai(ticker, curr_date):
def get_global_news_openai(curr_date):
config = get_config()
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
provider = config.get("llm_provider", "openai").lower()
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
else:
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
from tradingagents.utils.provider_utils import get_openai_client
client = get_openai_client(config)
response = client.responses.create(
model=config["quick_think_llm"],
@ -788,14 +775,8 @@ def get_global_news_openai(curr_date):
def get_fundamentals_openai(ticker, curr_date):
config = get_config()
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
provider = config.get("llm_provider", "openai").lower()
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
else:
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
from tradingagents.utils.provider_utils import get_openai_client
client = get_openai_client(config)
response = client.responses.create(
model=config["quick_think_llm"],

View File

@ -16,6 +16,13 @@ def get_api_key_for_provider(config):
"""
provider = config.get("llm_provider", "openai").lower()
# Handle custom provider first
if provider.startswith("custom"):
api_key = os.getenv("CUSTOM_API_KEY")
if not api_key:
print("Warning: CUSTOM_API_KEY not found in environment variables")
return api_key
# Map providers to their environment variables
api_key_mapping = {
"openai": "OPENAI_API_KEY",
@ -32,3 +39,24 @@ def get_api_key_for_provider(config):
print(f"Warning: {env_var} not found in environment variables")
return api_key
def get_openai_client(config):
"""Get a properly configured OpenAI client based on the provider configuration.
This function centralizes OpenAI client creation with correct API key resolution
for all providers that use OpenAI-compatible interfaces (OpenAI, OpenRouter,
Ollama, and custom providers).
Args:
config (dict): Configuration dictionary containing llm_provider and backend_url
Returns:
OpenAI: Configured OpenAI client instance
"""
from openai import OpenAI
api_key = get_api_key_for_provider(config)
backend_url = config.get("backend_url", "https://api.openai.com/v1")
return OpenAI(base_url=backend_url, api_key=api_key)