allow gemini
This commit is contained in:
parent
a2ba20a75c
commit
0776904d2e
|
|
@ -9,7 +9,7 @@ class FinancialSituationMemory:
|
|||
self.embedding = "nomic-embed-text"
|
||||
else:
|
||||
self.embedding = "text-embedding-3-small"
|
||||
self.client = OpenAI(base_url=config["backend_url"])
|
||||
self.client = OpenAI()
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,42 @@
|
|||
import tradingagents.default_config as default_config
|
||||
from typing import Dict, Optional
|
||||
import os
|
||||
|
||||
# Use default config but allow it to be overridden
|
||||
_config: Optional[Dict] = None
|
||||
DATA_DIR: Optional[str] = None
|
||||
|
||||
|
||||
def validate_api_keys(config: Dict) -> None:
|
||||
"""Validate that required API keys are present based on LLM provider."""
|
||||
llm_provider = config.get("llm_provider", "openai").lower()
|
||||
|
||||
# Always require OpenAI API key since many functions use it
|
||||
if not config.get("OPENAI_API_KEY"):
|
||||
raise ValueError(
|
||||
"OPENAI_API_KEY is required in environment variables. "
|
||||
"Many functions and tools require OpenAI API access."
|
||||
)
|
||||
|
||||
# Provider-specific validations
|
||||
if llm_provider == "openai":
|
||||
if not config.get("OPENAI_API_KEY"):
|
||||
raise ValueError("OPENAI_API_KEY is required for OpenAI provider")
|
||||
# elif llm_provider == "anthropic":
|
||||
# if not config.get("anthropic_api_key"):
|
||||
# raise ValueError("ANTHROPIC_API_KEY is required for Anthropic provider")
|
||||
elif llm_provider == "google":
|
||||
if not config.get("GOOGLE_API_KEY"):
|
||||
raise ValueError("GOOGLE_API_KEY is required for Google provider")
|
||||
|
||||
|
||||
def initialize_config():
|
||||
"""Initialize the configuration with default values."""
|
||||
global _config, DATA_DIR
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
DATA_DIR = _config["data_dir"]
|
||||
validate_api_keys(_config)
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
|
|
@ -21,6 +46,7 @@ def set_config(config: Dict):
|
|||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config.update(config)
|
||||
DATA_DIR = _config["data_dir"]
|
||||
validate_api_keys(_config)
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
|
|
|
|||
|
|
@ -704,10 +704,11 @@ def get_YFin_data(
|
|||
|
||||
def get_stock_news_openai(ticker, curr_date):
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
client = OpenAI(api_key=config["OPENAI_API_KEY"])
|
||||
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -739,10 +740,10 @@ def get_stock_news_openai(ticker, curr_date):
|
|||
|
||||
def get_global_news_openai(curr_date):
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
client = OpenAI()
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -774,10 +775,10 @@ def get_global_news_openai(curr_date):
|
|||
|
||||
def get_fundamentals_openai(ticker, curr_date):
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
client = OpenAI()
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
|
|
|
|||
|
|
@ -3,11 +3,16 @@ import os
|
|||
DEFAULT_CONFIG = {
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", # TODO: change to your own data directory
|
||||
"data_cache_dir": os.path.join(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"dataflows/data_cache",
|
||||
),
|
||||
# API Keys - load from environment variables
|
||||
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
|
||||
"FINNHUB_API_KEY": os.getenv("FINNHUB_API_KEY"),
|
||||
"GOOGLE_API_KEY": os.getenv("GOOGLE_API_KEY"),
|
||||
"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY"),
|
||||
# LLM settings
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "o4-mini",
|
||||
|
|
|
|||
Loading…
Reference in New Issue