Fix memory embeddings for google with aysncio event loop

This commit is contained in:
sdk451 2025-07-24 23:01:27 +10:00
parent c65d764908
commit c13bc042c8
1 changed files with 21 additions and 3 deletions

View File

@ -2,6 +2,9 @@ import chromadb
from chromadb.config import Settings
from openai import OpenAI
import os
import asyncio
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from sentence_transformers import SentenceTransformer
class FinancialSituationMemory:
@ -12,9 +15,20 @@ class FinancialSituationMemory:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.embedding_model = None
elif self.provider == "google":
import asyncio
try:
asyncio.get_running_loop()
except RuntimeError:
asyncio.set_event_loop(asyncio.new_event_loop())
google_api_key = os.getenv("GOOGLE_API_KEY")
self.embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=google_api_key)
self.client = None
elif self.provider == "anthropic":
self.embedding_model = None
self.client = None
else:
# Use a local embedding model for non-OpenAI providers
from sentence_transformers import SentenceTransformer
# Use a local embedding model for other non-OpenAI providers
self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
self.client = None
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
@ -26,6 +40,10 @@ class FinancialSituationMemory:
model=self.embedding, input=text
)
return response.data[0].embedding
elif self.provider == "google":
return self.embedding_model.embed_query(text)
elif self.provider == "anthropic":
raise NotImplementedError("Memory features are currently not supported for Anthropic provider. Please use OpenAI or Google for memory-enabled workflows.")
else:
# Use local embedding model
return self.embedding_model.encode(text).tolist()
@ -54,7 +72,7 @@ class FinancialSituationMemory:
)
def get_memories(self, current_situation, n_matches=1):
"""Find matching recommendations using OpenAI embeddings"""
"""Find matching recommendations using provider-appropriate embeddings"""
query_embedding = self.get_embedding(current_situation)
results = self.situation_collection.query(