This commit is contained in:
MarkLo 2025-11-16 05:28:16 +08:00
parent d522d3ec43
commit 82829741b9
3 changed files with 16 additions and 5 deletions

View File

@ -1,3 +1,4 @@
import os
import chromadb
from chromadb.config import Settings
from openai import OpenAI
@ -9,7 +10,9 @@ class FinancialSituationMemory:
self.embedding = "nomic-embed-text"
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
# Get the OpenAI API key from environment variable
openai_api_key = os.getenv("OPENAI_API_KEY")
self.client = OpenAI(base_url=config["backend_url"], api_key=openai_api_key)
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)

View File

@ -1,3 +1,4 @@
import os
from openai import OpenAI
from .config import get_config
@ -15,7 +16,9 @@ def get_stock_news_openai(query, start_date, end_date):
str: 模型的文字回應
"""
config = get_config()
client = OpenAI(base_url=config["backend_url"])
# Get the OpenAI API key from environment variable
openai_api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=openai_api_key)
response = client.responses.create(
model=config["quick_think_llm"],
@ -61,7 +64,9 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
str: 模型的文字回應
"""
config = get_config()
client = OpenAI(base_url=config["backend_url"])
# Get the OpenAI API key from environment variable
openai_api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=openai_api_key)
response = client.responses.create(
model=config["quick_think_llm"],
@ -106,7 +111,9 @@ def get_fundamentals_openai(ticker, curr_date):
str: 模型的文字回應
"""
config = get_config()
client = OpenAI(base_url=config["backend_url"])
# Get the OpenAI API key from environment variable
openai_api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(base_url=config["backend_url"], api_key=openai_api_key)
response = client.responses.create(
model=config["quick_think_llm"],

View File

@ -108,7 +108,8 @@ class ConditionalLogic:
state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds
):
return "Research Manager"
if state["investment_debate_state"]["current_response"].startswith("Bull"):
# 檢查中文前綴(因為研究員使用中文格式化響應)
if state["investment_debate_state"]["current_response"].startswith("看漲"):
return "Bear Researcher"
return "Bull Researcher"