fix: workflow flaw with openrouter
This commit is contained in:
parent
10f5b9bf12
commit
84cf5fbfea
|
|
@ -1,7 +1,6 @@
|
|||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from tradingagents.utils.provider_utils import get_openai_client
|
||||
|
||||
|
||||
class FinancialSituationMemory:
|
||||
|
|
@ -11,14 +10,7 @@ class FinancialSituationMemory:
|
|||
else:
|
||||
self.embedding = "text-embedding-3-small"
|
||||
|
||||
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
|
||||
provider = config.get("llm_provider", "openai").lower()
|
||||
if provider.startswith("custom"):
|
||||
api_key = os.getenv("CUSTOM_API_KEY")
|
||||
else:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
self.client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
||||
self.client = get_openai_client(config)
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import os
|
|||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import yfinance as yf
|
||||
from openai import OpenAI
|
||||
from .config import get_config, set_config, DATA_DIR
|
||||
|
||||
|
||||
|
|
@ -704,14 +703,8 @@ def get_YFin_data(
|
|||
|
||||
def get_stock_news_openai(ticker, curr_date):
|
||||
config = get_config()
|
||||
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
|
||||
provider = config.get("llm_provider", "openai").lower()
|
||||
if provider.startswith("custom"):
|
||||
api_key = os.getenv("CUSTOM_API_KEY")
|
||||
else:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
||||
from tradingagents.utils.provider_utils import get_openai_client
|
||||
client = get_openai_client(config)
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
|
|
@ -746,14 +739,8 @@ def get_stock_news_openai(ticker, curr_date):
|
|||
|
||||
def get_global_news_openai(curr_date):
|
||||
config = get_config()
|
||||
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
|
||||
provider = config.get("llm_provider", "openai").lower()
|
||||
if provider.startswith("custom"):
|
||||
api_key = os.getenv("CUSTOM_API_KEY")
|
||||
else:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
||||
from tradingagents.utils.provider_utils import get_openai_client
|
||||
client = get_openai_client(config)
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
|
|
@ -788,14 +775,8 @@ def get_global_news_openai(curr_date):
|
|||
|
||||
def get_fundamentals_openai(ticker, curr_date):
|
||||
config = get_config()
|
||||
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
|
||||
provider = config.get("llm_provider", "openai").lower()
|
||||
if provider.startswith("custom"):
|
||||
api_key = os.getenv("CUSTOM_API_KEY")
|
||||
else:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
||||
from tradingagents.utils.provider_utils import get_openai_client
|
||||
client = get_openai_client(config)
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
|
|
|
|||
|
|
@ -16,6 +16,13 @@ def get_api_key_for_provider(config):
|
|||
"""
|
||||
provider = config.get("llm_provider", "openai").lower()
|
||||
|
||||
# Handle custom provider first
|
||||
if provider.startswith("custom"):
|
||||
api_key = os.getenv("CUSTOM_API_KEY")
|
||||
if not api_key:
|
||||
print("Warning: CUSTOM_API_KEY not found in environment variables")
|
||||
return api_key
|
||||
|
||||
# Map providers to their environment variables
|
||||
api_key_mapping = {
|
||||
"openai": "OPENAI_API_KEY",
|
||||
|
|
@ -32,3 +39,24 @@ def get_api_key_for_provider(config):
|
|||
print(f"Warning: {env_var} not found in environment variables")
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
def get_openai_client(config):
|
||||
"""Get a properly configured OpenAI client based on the provider configuration.
|
||||
|
||||
This function centralizes OpenAI client creation with correct API key resolution
|
||||
for all providers that use OpenAI-compatible interfaces (OpenAI, OpenRouter,
|
||||
Ollama, and custom providers).
|
||||
|
||||
Args:
|
||||
config (dict): Configuration dictionary containing llm_provider and backend_url
|
||||
|
||||
Returns:
|
||||
OpenAI: Configured OpenAI client instance
|
||||
"""
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = get_api_key_for_provider(config)
|
||||
backend_url = config.get("backend_url", "https://api.openai.com/v1")
|
||||
|
||||
return OpenAI(base_url=backend_url, api_key=api_key)
|
||||
|
|
|
|||
Loading…
Reference in New Issue