allow gemini

This commit is contained in:
Flora Xu 2025-07-05 17:26:32 -07:00
parent a2ba20a75c
commit 0776904d2e
4 changed files with 40 additions and 8 deletions

View File

@ -9,7 +9,7 @@ class FinancialSituationMemory:
self.embedding = "nomic-embed-text"
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.client = OpenAI()
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)

View File

@ -1,17 +1,42 @@
import tradingagents.default_config as default_config
from typing import Dict, Optional
import os
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
DATA_DIR: Optional[str] = None
def validate_api_keys(config: Dict) -> None:
"""Validate that required API keys are present based on LLM provider."""
llm_provider = config.get("llm_provider", "openai").lower()
# Always require OpenAI API key since many functions use it
if not config.get("OPENAI_API_KEY"):
raise ValueError(
"OPENAI_API_KEY is required in environment variables. "
"Many functions and tools require OpenAI API access."
)
# Provider-specific validations
if llm_provider == "openai":
if not config.get("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY is required for OpenAI provider")
# elif llm_provider == "anthropic":
# if not config.get("anthropic_api_key"):
# raise ValueError("ANTHROPIC_API_KEY is required for Anthropic provider")
elif llm_provider == "google":
if not config.get("GOOGLE_API_KEY"):
raise ValueError("GOOGLE_API_KEY is required for Google provider")
def initialize_config():
"""Initialize the configuration with default values."""
global _config, DATA_DIR
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
DATA_DIR = _config["data_dir"]
validate_api_keys(_config)
def set_config(config: Dict):
@ -21,6 +46,7 @@ def set_config(config: Dict):
_config = default_config.DEFAULT_CONFIG.copy()
_config.update(config)
DATA_DIR = _config["data_dir"]
validate_api_keys(_config)
def get_config() -> Dict:

View File

@ -704,10 +704,11 @@ def get_YFin_data(
def get_stock_news_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
client = OpenAI(api_key=config["OPENAI_API_KEY"])
response = client.responses.create(
model=config["quick_think_llm"],
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
input=[
{
"role": "system",
@ -739,10 +740,10 @@ def get_stock_news_openai(ticker, curr_date):
def get_global_news_openai(curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
client = OpenAI()
response = client.responses.create(
model=config["quick_think_llm"],
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
input=[
{
"role": "system",
@ -774,10 +775,10 @@ def get_global_news_openai(curr_date):
def get_fundamentals_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
client = OpenAI()
response = client.responses.create(
model=config["quick_think_llm"],
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
input=[
{
"role": "system",

View File

@ -3,11 +3,16 @@ import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", # TODO: change to your own data directory
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# API Keys - load from environment variables
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
"FINNHUB_API_KEY": os.getenv("FINNHUB_API_KEY"),
"GOOGLE_API_KEY": os.getenv("GOOGLE_API_KEY"),
"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY"),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "o4-mini",