fix: workflow flaw when using the OpenRouter provider
This commit is contained in:
parent
10f5b9bf12
commit
84cf5fbfea
|
|
@ -1,7 +1,6 @@
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from openai import OpenAI
|
from tradingagents.utils.provider_utils import get_openai_client
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
class FinancialSituationMemory:
|
class FinancialSituationMemory:
|
||||||
|
|
@ -11,14 +10,7 @@ class FinancialSituationMemory:
|
||||||
else:
|
else:
|
||||||
self.embedding = "text-embedding-3-small"
|
self.embedding = "text-embedding-3-small"
|
||||||
|
|
||||||
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
|
self.client = get_openai_client(config)
|
||||||
provider = config.get("llm_provider", "openai").lower()
|
|
||||||
if provider.startswith("custom"):
|
|
||||||
api_key = os.getenv("CUSTOM_API_KEY")
|
|
||||||
else:
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
|
|
||||||
self.client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ import os
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from openai import OpenAI
|
|
||||||
from .config import get_config, set_config, DATA_DIR
|
from .config import get_config, set_config, DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -704,14 +703,8 @@ def get_YFin_data(
|
||||||
|
|
||||||
def get_stock_news_openai(ticker, curr_date):
|
def get_stock_news_openai(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
|
from tradingagents.utils.provider_utils import get_openai_client
|
||||||
provider = config.get("llm_provider", "openai").lower()
|
client = get_openai_client(config)
|
||||||
if provider.startswith("custom"):
|
|
||||||
api_key = os.getenv("CUSTOM_API_KEY")
|
|
||||||
else:
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
|
|
||||||
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model=config["quick_think_llm"],
|
||||||
|
|
@ -746,14 +739,8 @@ def get_stock_news_openai(ticker, curr_date):
|
||||||
|
|
||||||
def get_global_news_openai(curr_date):
|
def get_global_news_openai(curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
|
from tradingagents.utils.provider_utils import get_openai_client
|
||||||
provider = config.get("llm_provider", "openai").lower()
|
client = get_openai_client(config)
|
||||||
if provider.startswith("custom"):
|
|
||||||
api_key = os.getenv("CUSTOM_API_KEY")
|
|
||||||
else:
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
|
|
||||||
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model=config["quick_think_llm"],
|
||||||
|
|
@ -788,14 +775,8 @@ def get_global_news_openai(curr_date):
|
||||||
|
|
||||||
def get_fundamentals_openai(ticker, curr_date):
|
def get_fundamentals_openai(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
# Use CUSTOM_API_KEY if provider is custom, otherwise use OPENAI_API_KEY
|
from tradingagents.utils.provider_utils import get_openai_client
|
||||||
provider = config.get("llm_provider", "openai").lower()
|
client = get_openai_client(config)
|
||||||
if provider.startswith("custom"):
|
|
||||||
api_key = os.getenv("CUSTOM_API_KEY")
|
|
||||||
else:
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
|
|
||||||
client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model=config["quick_think_llm"],
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,13 @@ def get_api_key_for_provider(config):
|
||||||
"""
|
"""
|
||||||
provider = config.get("llm_provider", "openai").lower()
|
provider = config.get("llm_provider", "openai").lower()
|
||||||
|
|
||||||
|
# Handle custom provider first
|
||||||
|
if provider.startswith("custom"):
|
||||||
|
api_key = os.getenv("CUSTOM_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
print("Warning: CUSTOM_API_KEY not found in environment variables")
|
||||||
|
return api_key
|
||||||
|
|
||||||
# Map providers to their environment variables
|
# Map providers to their environment variables
|
||||||
api_key_mapping = {
|
api_key_mapping = {
|
||||||
"openai": "OPENAI_API_KEY",
|
"openai": "OPENAI_API_KEY",
|
||||||
|
|
@ -32,3 +39,24 @@ def get_api_key_for_provider(config):
|
||||||
print(f"Warning: {env_var} not found in environment variables")
|
print(f"Warning: {env_var} not found in environment variables")
|
||||||
|
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
def get_openai_client(config):
|
||||||
|
"""Get a properly configured OpenAI client based on the provider configuration.
|
||||||
|
|
||||||
|
This function centralizes OpenAI client creation with correct API key resolution
|
||||||
|
for all providers that use OpenAI-compatible interfaces (OpenAI, OpenRouter,
|
||||||
|
Ollama, and custom providers).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): Configuration dictionary containing llm_provider and backend_url
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpenAI: Configured OpenAI client instance
|
||||||
|
"""
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
api_key = get_api_key_for_provider(config)
|
||||||
|
backend_url = config.get("backend_url", "https://api.openai.com/v1")
|
||||||
|
|
||||||
|
return OpenAI(base_url=backend_url, api_key=api_key)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue