diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index a1934bd8..7ec1ac3b 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,20 +1,17 @@ import chromadb from chromadb.config import Settings from openai import OpenAI -import numpy as np class FinancialSituationMemory: def __init__(self, name): self.client = OpenAI() self.chroma_client = chromadb.Client(Settings(allow_reset=True)) - self.situation_collection = self.chroma_client.create_collection(name=name) + self.situation_collection = self.chroma_client.get_or_create_collection(name=name) def get_embedding(self, text): """Get OpenAI embedding for a text""" - response = self.client.embeddings.create( - model="text-embedding-ada-002", input=text - ) + response = self.client.embeddings.create(model="text-embedding-ada-002", input=text) return response.data[0].embedding def add_situations(self, situations_and_advice): @@ -92,7 +89,7 @@ if __name__ == "__main__": # Example query current_situation = """ - Market showing increased volatility in tech sector, with institutional investors + Market showing increased volatility in tech sector, with institutional investors reducing positions and rising interest rates affecting growth stock valuations """