68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
import logging
|
|
import chromadb
|
|
from chromadb.config import Settings
|
|
from openai import OpenAI
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FinancialSituationMemory:
    """Store (situation, recommendation) pairs and retrieve them by embedding.

    Situations are embedded through an OpenAI-compatible embeddings endpoint
    and stored in a Chroma collection; the recommendation for each situation
    is kept in that record's metadata under the ``"recommendation"`` key.
    """

    def __init__(self, name, config):
        """Create the embedding client and the backing Chroma collection.

        Args:
            name: Name of the Chroma collection to get or create.
            config: Mapping that must contain ``"backend_url"``, the base URL
                of the OpenAI-compatible API.

        Raises:
            KeyError: If ``config`` has no ``"backend_url"`` entry.
        """
        # The embedding model depends on the backend: this exact local URL
        # (presumably a local Ollama server -- confirm) gets
        # "nomic-embed-text"; every other backend gets OpenAI's model.
        if config["backend_url"] == "http://localhost:11434/v1":
            self.embedding = "nomic-embed-text"
        else:
            self.embedding = "text-embedding-3-small"
        self.client = OpenAI(base_url=config["backend_url"])
        self.chroma_client = chromadb.Client(Settings(allow_reset=True))
        self.situation_collection = self.chroma_client.get_or_create_collection(
            name=name
        )

    def get_embedding(self, text):
        """Return the embedding vector for ``text``.

        Sends ``text`` to the embeddings endpoint using the model chosen in
        ``__init__`` and returns the embedding of the first result.
        """
        response = self.client.embeddings.create(
            model=self.embedding, input=text
        )
        return response.data[0].embedding

    def add_situations(self, situations_and_advice):
        """Embed and store a batch of (situation, recommendation) pairs.

        Args:
            situations_and_advice: Iterable of ``(situation, recommendation)``
                pairs. An empty iterable is a no-op.

        Each record's id is the string form of its position, counted from the
        number of records already in the collection.
        """
        # Materialise first so one-shot iterables work and emptiness is known.
        pairs = list(situations_and_advice)
        if not pairs:
            # Nothing to store; an add with empty lists would be rejected.
            return

        offset = self.situation_collection.count()

        situations = []
        advice = []
        ids = []
        embeddings = []
        for i, (situation, recommendation) in enumerate(pairs):
            situations.append(situation)
            advice.append(recommendation)
            ids.append(str(offset + i))
            embeddings.append(self.get_embedding(situation))

        self.situation_collection.add(
            documents=situations,
            metadatas=[{"recommendation": rec} for rec in advice],
            embeddings=embeddings,
            ids=ids,
        )

    def get_memories(self, current_situation, n_matches=1):
        """Return the stored situations closest to ``current_situation``.

        Args:
            current_situation: Text describing the situation to look up.
            n_matches: Maximum number of matches to return. Fewer are returned
                when the collection holds fewer records.

        Returns:
            A list of dicts with keys ``"matched_situation"``,
            ``"recommendation"`` and ``"similarity_score"``, in the order the
            collection returned them. The list is empty when nothing has been
            stored yet.

        NOTE(review): ``similarity_score`` is ``1 - distance``. What that
        value means depends on the collection's distance metric, which this
        class does not set explicitly -- confirm before relying on its range.
        """
        stored = self.situation_collection.count()
        if stored == 0:
            # Skip the embedding request: there is nothing to match against.
            return []

        query_embedding = self.get_embedding(current_situation)

        results = self.situation_collection.query(
            query_embeddings=[query_embedding],
            # Never ask for more results than there are stored records.
            n_results=min(n_matches, stored),
            include=["metadatas", "documents", "distances"],
        )

        # One query embedding was sent, so only the first result row is used.
        return [
            {
                "matched_situation": document,
                "recommendation": metadata["recommendation"],
                "similarity_score": 1 - distance,
            }
            for document, metadata, distance in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]
|