From 84cf5fbfeabdb93c44a98d297db28979b18d6c8d Mon Sep 17 00:00:00 2001 From: mogita Date: Sat, 16 Aug 2025 15:08:08 +0800 Subject: [PATCH] fix: workflow flaw with openrouter --- tradingagents/agents/utils/memory.py | 12 ++--------- tradingagents/dataflows/interface.py | 31 ++++++--------------------- tradingagents/utils/provider_utils.py | 28 ++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 5e14c2ef..76ca0a8e 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,7 +1,6 @@ import chromadb from chromadb.config import Settings -from openai import OpenAI -import os +from tradingagents.utils.provider_utils import get_openai_client class FinancialSituationMemory: @@ -11,14 +10,7 @@ class FinancialSituationMemory: else: self.embedding = "text-embedding-3-small" - # Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY - provider = config.get("llm_provider", "openai").lower() - if provider.startswith("custom"): - api_key = os.getenv("CUSTOM_API_KEY") - else: - api_key = os.getenv("OPENAI_API_KEY") - - self.client = OpenAI(base_url=config["backend_url"], api_key=api_key) + self.client = get_openai_client(config) self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 2c73baf0..d6f0c7a2 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -12,7 +12,6 @@ import os import pandas as pd from tqdm import tqdm import yfinance as yf -from openai import OpenAI from .config import get_config, set_config, DATA_DIR @@ -704,14 +703,8 @@ def get_YFin_data( def get_stock_news_openai(ticker, curr_date): config = get_config() - # Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY - provider = config.get("llm_provider", "openai").lower() - if provider.startswith("custom"): - api_key = os.getenv("CUSTOM_API_KEY") - else: - api_key = os.getenv("OPENAI_API_KEY") - - client = OpenAI(base_url=config["backend_url"], api_key=api_key) + from tradingagents.utils.provider_utils import get_openai_client + client = get_openai_client(config) response = client.responses.create( model=config["quick_think_llm"], @@ -746,14 +739,8 @@ def get_stock_news_openai(ticker, curr_date): def get_global_news_openai(curr_date): config = get_config() - # Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY - provider = config.get("llm_provider", "openai").lower() - if provider.startswith("custom"): - api_key = os.getenv("CUSTOM_API_KEY") - else: - api_key = os.getenv("OPENAI_API_KEY") - - client = OpenAI(base_url=config["backend_url"], api_key=api_key) + from tradingagents.utils.provider_utils import get_openai_client + client = get_openai_client(config) response = client.responses.create( model=config["quick_think_llm"], @@ -788,14 +775,8 @@ def get_global_news_openai(curr_date): def get_fundamentals_openai(ticker, curr_date): config = get_config() - # Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY - provider = config.get("llm_provider", "openai").lower() - if provider.startswith("custom"): - api_key = os.getenv("CUSTOM_API_KEY") - else: - api_key = os.getenv("OPENAI_API_KEY") - - client = OpenAI(base_url=config["backend_url"], api_key=api_key) + from tradingagents.utils.provider_utils import get_openai_client + client = get_openai_client(config) response = client.responses.create( model=config["quick_think_llm"], diff --git a/tradingagents/utils/provider_utils.py b/tradingagents/utils/provider_utils.py index 426a0a39..9beaad89 100644 --- a/tradingagents/utils/provider_utils.py +++ b/tradingagents/utils/provider_utils.py @@ -16,6 +16,13 @@ def get_api_key_for_provider(config): """ provider = config.get("llm_provider", "openai").lower() + # Handle custom provider first + if provider.startswith("custom"): + api_key = os.getenv("CUSTOM_API_KEY") + if not api_key: + print("Warning: CUSTOM_API_KEY not found in environment variables") + return api_key + # Map providers to their environment variables api_key_mapping = { "openai": "OPENAI_API_KEY", @@ -32,3 +39,24 @@ def get_api_key_for_provider(config): print(f"Warning: {env_var} not found in environment variables") return api_key + + +def get_openai_client(config): + """Get a properly configured OpenAI client based on the provider configuration. + + This function centralizes OpenAI client creation with correct API key resolution + for all providers that use OpenAI-compatible interfaces (OpenAI, OpenRouter, + Ollama, and custom providers). + + Args: + config (dict): Configuration dictionary containing llm_provider and backend_url + + Returns: + OpenAI: Configured OpenAI client instance + """ + from openai import OpenAI + + api_key = get_api_key_for_provider(config) + backend_url = config.get("backend_url", "https://api.openai.com/v1") + + return OpenAI(base_url=backend_url, api_key=api_key)