diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 8f492137..092fb44b 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,3 +1,4 @@ +import os import chromadb from chromadb.config import Settings from openai import OpenAI @@ -9,7 +10,9 @@ class FinancialSituationMemory: self.embedding = "nomic-embed-text" else: self.embedding = "text-embedding-3-small" - self.client = OpenAI(base_url=config["backend_url"]) + # Get the OpenAI API key from environment variable + openai_api_key = os.getenv("OPENAI_API_KEY") + self.client = OpenAI(base_url=config["backend_url"], api_key=openai_api_key) self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) diff --git a/tradingagents/dataflows/openai.py b/tradingagents/dataflows/openai.py index 1aa43871..6174bbe0 100644 --- a/tradingagents/dataflows/openai.py +++ b/tradingagents/dataflows/openai.py @@ -1,3 +1,4 @@ +import os from openai import OpenAI from .config import get_config @@ -15,7 +16,9 @@ def get_stock_news_openai(query, start_date, end_date): str: 模型的文字回應。 """ config = get_config() - client = OpenAI(base_url=config["backend_url"]) + # Get the OpenAI API key from environment variable + openai_api_key = os.getenv("OPENAI_API_KEY") + client = OpenAI(base_url=config["backend_url"], api_key=openai_api_key) response = client.responses.create( model=config["quick_think_llm"], @@ -61,7 +64,9 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5): str: 模型的文字回應。 """ config = get_config() - client = OpenAI(base_url=config["backend_url"]) + # Get the OpenAI API key from environment variable + openai_api_key = os.getenv("OPENAI_API_KEY") + client = OpenAI(base_url=config["backend_url"], api_key=openai_api_key) response = client.responses.create( model=config["quick_think_llm"], @@ -106,7 +111,9 @@ def get_fundamentals_openai(ticker, curr_date): str: 模型的文字回應。 """ config = get_config() - client = OpenAI(base_url=config["backend_url"]) + # Get the OpenAI API key from environment variable + openai_api_key = os.getenv("OPENAI_API_KEY") + client = OpenAI(base_url=config["backend_url"], api_key=openai_api_key) response = client.responses.create( model=config["quick_think_llm"], diff --git a/tradingagents/graph/conditional_logic.py b/tradingagents/graph/conditional_logic.py index 51ffd226..6e647a55 100644 --- a/tradingagents/graph/conditional_logic.py +++ b/tradingagents/graph/conditional_logic.py @@ -108,7 +108,8 @@ class ConditionalLogic: state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds ): return "Research Manager" - if state["investment_debate_state"]["current_response"].startswith("Bull"): + # 檢查中文前綴(因為研究員使用中文格式化響應) + if state["investment_debate_state"]["current_response"].startswith("看漲"): return "Bear Researcher" return "Bull Researcher"