allow gemini
This commit is contained in:
parent
a2ba20a75c
commit
0776904d2e
|
|
@ -9,7 +9,7 @@ class FinancialSituationMemory:
|
||||||
self.embedding = "nomic-embed-text"
|
self.embedding = "nomic-embed-text"
|
||||||
else:
|
else:
|
||||||
self.embedding = "text-embedding-3-small"
|
self.embedding = "text-embedding-3-small"
|
||||||
self.client = OpenAI(base_url=config["backend_url"])
|
self.client = OpenAI()
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,42 @@
|
||||||
import tradingagents.default_config as default_config
|
import tradingagents.default_config as default_config
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
import os
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
# Use default config but allow it to be overridden
|
||||||
_config: Optional[Dict] = None
|
_config: Optional[Dict] = None
|
||||||
DATA_DIR: Optional[str] = None
|
DATA_DIR: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_api_keys(config: Dict) -> None:
|
||||||
|
"""Validate that required API keys are present based on LLM provider."""
|
||||||
|
llm_provider = config.get("llm_provider", "openai").lower()
|
||||||
|
|
||||||
|
# Always require OpenAI API key since many functions use it
|
||||||
|
if not config.get("OPENAI_API_KEY"):
|
||||||
|
raise ValueError(
|
||||||
|
"OPENAI_API_KEY is required in environment variables. "
|
||||||
|
"Many functions and tools require OpenAI API access."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Provider-specific validations
|
||||||
|
if llm_provider == "openai":
|
||||||
|
if not config.get("OPENAI_API_KEY"):
|
||||||
|
raise ValueError("OPENAI_API_KEY is required for OpenAI provider")
|
||||||
|
# elif llm_provider == "anthropic":
|
||||||
|
# if not config.get("anthropic_api_key"):
|
||||||
|
# raise ValueError("ANTHROPIC_API_KEY is required for Anthropic provider")
|
||||||
|
elif llm_provider == "google":
|
||||||
|
if not config.get("GOOGLE_API_KEY"):
|
||||||
|
raise ValueError("GOOGLE_API_KEY is required for Google provider")
|
||||||
|
|
||||||
|
|
||||||
def initialize_config():
|
def initialize_config():
|
||||||
"""Initialize the configuration with default values."""
|
"""Initialize the configuration with default values."""
|
||||||
global _config, DATA_DIR
|
global _config, DATA_DIR
|
||||||
if _config is None:
|
if _config is None:
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
_config = default_config.DEFAULT_CONFIG.copy()
|
||||||
DATA_DIR = _config["data_dir"]
|
DATA_DIR = _config["data_dir"]
|
||||||
|
validate_api_keys(_config)
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Dict):
|
def set_config(config: Dict):
|
||||||
|
|
@ -21,6 +46,7 @@ def set_config(config: Dict):
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
_config = default_config.DEFAULT_CONFIG.copy()
|
||||||
_config.update(config)
|
_config.update(config)
|
||||||
DATA_DIR = _config["data_dir"]
|
DATA_DIR = _config["data_dir"]
|
||||||
|
validate_api_keys(_config)
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> Dict:
|
def get_config() -> Dict:
|
||||||
|
|
|
||||||
|
|
@ -704,10 +704,11 @@ def get_YFin_data(
|
||||||
|
|
||||||
def get_stock_news_openai(ticker, curr_date):
|
def get_stock_news_openai(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
client = OpenAI(api_key=config["OPENAI_API_KEY"])
|
||||||
|
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
|
||||||
input=[
|
input=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
|
@ -739,10 +740,10 @@ def get_stock_news_openai(ticker, curr_date):
|
||||||
|
|
||||||
def get_global_news_openai(curr_date):
|
def get_global_news_openai(curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
client = OpenAI()
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
|
||||||
input=[
|
input=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
|
@ -774,10 +775,10 @@ def get_global_news_openai(curr_date):
|
||||||
|
|
||||||
def get_fundamentals_openai(ticker, curr_date):
|
def get_fundamentals_openai(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
client = OpenAI()
|
||||||
|
|
||||||
response = client.responses.create(
|
response = client.responses.create(
|
||||||
model=config["quick_think_llm"],
|
model= 'gpt-4o-mini', # TODO: change to config["quick_think_llm"],
|
||||||
input=[
|
input=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,16 @@ import os
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", # TODO: change to your own data directory
|
||||||
"data_cache_dir": os.path.join(
|
"data_cache_dir": os.path.join(
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"dataflows/data_cache",
|
"dataflows/data_cache",
|
||||||
),
|
),
|
||||||
|
# API Keys - load from environment variables
|
||||||
|
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
|
||||||
|
"FINNHUB_API_KEY": os.getenv("FINNHUB_API_KEY"),
|
||||||
|
"GOOGLE_API_KEY": os.getenv("GOOGLE_API_KEY"),
|
||||||
|
"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY"),
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"llm_provider": "openai",
|
"llm_provider": "openai",
|
||||||
"deep_think_llm": "o4-mini",
|
"deep_think_llm": "o4-mini",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue