diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 49aff9cd..f92d38ad 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -2,6 +2,9 @@ import chromadb from chromadb.config import Settings from openai import OpenAI import os +import asyncio +from langchain_google_genai import GoogleGenerativeAIEmbeddings +from sentence_transformers import SentenceTransformer class FinancialSituationMemory: @@ -12,9 +15,20 @@ class FinancialSituationMemory: self.embedding = "text-embedding-3-small" self.client = OpenAI(base_url=config["backend_url"]) self.embedding_model = None + elif self.provider == "google": + import asyncio + try: + asyncio.get_running_loop() + except RuntimeError: + asyncio.set_event_loop(asyncio.new_event_loop()) + google_api_key = os.getenv("GOOGLE_API_KEY") + self.embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=google_api_key) + self.client = None + elif self.provider == "anthropic": + self.embedding_model = None + self.client = None else: - # Use a local embedding model for non-OpenAI providers - from sentence_transformers import SentenceTransformer + # Use a local embedding model for other non-OpenAI providers self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2") self.client = None self.chroma_client = chromadb.Client(Settings(allow_reset=True)) @@ -26,6 +40,10 @@ class FinancialSituationMemory: model=self.embedding, input=text ) return response.data[0].embedding + elif self.provider == "google": + return self.embedding_model.embed_query(text) + elif self.provider == "anthropic": + raise NotImplementedError("Memory features are currently not supported for Anthropic provider. Please use OpenAI or Google for memory-enabled workflows.") else: # Use local embedding model return self.embedding_model.encode(text).tolist() @@ -54,7 +72,7 @@ class FinancialSituationMemory: ) def get_memories(self, current_situation, n_matches=1): - """Find matching recommendations using OpenAI embeddings""" + """Find matching recommendations using provider-appropriate embeddings""" query_embedding = self.get_embedding(current_situation) results = self.situation_collection.query(