Fix memory embeddings for google with aysncio event loop
This commit is contained in:
parent
c65d764908
commit
c13bc042c8
|
|
@ -2,6 +2,9 @@ import chromadb
|
|||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import os
|
||||
import asyncio
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
class FinancialSituationMemory:
|
||||
|
|
@ -12,9 +15,20 @@ class FinancialSituationMemory:
|
|||
self.embedding = "text-embedding-3-small"
|
||||
self.client = OpenAI(base_url=config["backend_url"])
|
||||
self.embedding_model = None
|
||||
elif self.provider == "google":
|
||||
import asyncio
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
self.embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=google_api_key)
|
||||
self.client = None
|
||||
elif self.provider == "anthropic":
|
||||
self.embedding_model = None
|
||||
self.client = None
|
||||
else:
|
||||
# Use a local embedding model for non-OpenAI providers
|
||||
from sentence_transformers import SentenceTransformer
|
||||
# Use a local embedding model for other non-OpenAI providers
|
||||
self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
self.client = None
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
|
|
@ -26,6 +40,10 @@ class FinancialSituationMemory:
|
|||
model=self.embedding, input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
elif self.provider == "google":
|
||||
return self.embedding_model.embed_query(text)
|
||||
elif self.provider == "anthropic":
|
||||
raise NotImplementedError("Memory features are currently not supported for Anthropic provider. Please use OpenAI or Google for memory-enabled workflows.")
|
||||
else:
|
||||
# Use local embedding model
|
||||
return self.embedding_model.encode(text).tolist()
|
||||
|
|
@ -54,7 +72,7 @@ class FinancialSituationMemory:
|
|||
)
|
||||
|
||||
def get_memories(self, current_situation, n_matches=1):
|
||||
"""Find matching recommendations using OpenAI embeddings"""
|
||||
"""Find matching recommendations using provider-appropriate embeddings"""
|
||||
query_embedding = self.get_embedding(current_situation)
|
||||
|
||||
results = self.situation_collection.query(
|
||||
|
|
|
|||
Loading…
Reference in New Issue