Fix memory embeddings for google with aysncio event loop
This commit is contained in:
parent
c65d764908
commit
c13bc042c8
|
|
@ -2,6 +2,9 @@ import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
|
||||||
class FinancialSituationMemory:
|
class FinancialSituationMemory:
|
||||||
|
|
@ -12,9 +15,20 @@ class FinancialSituationMemory:
|
||||||
self.embedding = "text-embedding-3-small"
|
self.embedding = "text-embedding-3-small"
|
||||||
self.client = OpenAI(base_url=config["backend_url"])
|
self.client = OpenAI(base_url=config["backend_url"])
|
||||||
self.embedding_model = None
|
self.embedding_model = None
|
||||||
|
elif self.provider == "google":
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||||
|
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||||
|
self.embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=google_api_key)
|
||||||
|
self.client = None
|
||||||
|
elif self.provider == "anthropic":
|
||||||
|
self.embedding_model = None
|
||||||
|
self.client = None
|
||||||
else:
|
else:
|
||||||
# Use a local embedding model for non-OpenAI providers
|
# Use a local embedding model for other non-OpenAI providers
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
self.client = None
|
self.client = None
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||||
|
|
@ -26,6 +40,10 @@ class FinancialSituationMemory:
|
||||||
model=self.embedding, input=text
|
model=self.embedding, input=text
|
||||||
)
|
)
|
||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
|
elif self.provider == "google":
|
||||||
|
return self.embedding_model.embed_query(text)
|
||||||
|
elif self.provider == "anthropic":
|
||||||
|
raise NotImplementedError("Memory features are currently not supported for Anthropic provider. Please use OpenAI or Google for memory-enabled workflows.")
|
||||||
else:
|
else:
|
||||||
# Use local embedding model
|
# Use local embedding model
|
||||||
return self.embedding_model.encode(text).tolist()
|
return self.embedding_model.encode(text).tolist()
|
||||||
|
|
@ -54,7 +72,7 @@ class FinancialSituationMemory:
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_memories(self, current_situation, n_matches=1):
|
def get_memories(self, current_situation, n_matches=1):
|
||||||
"""Find matching recommendations using OpenAI embeddings"""
|
"""Find matching recommendations using provider-appropriate embeddings"""
|
||||||
query_embedding = self.get_embedding(current_situation)
|
query_embedding = self.get_embedding(current_situation)
|
||||||
|
|
||||||
results = self.situation_collection.query(
|
results = self.situation_collection.query(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue