From c65d764908c4d4bdf4ceacb351988582e6d0b8eb Mon Sep 17 00:00:00 2001
From: sdk451
Date: Wed, 23 Jul 2025 11:49:09 +1000
Subject: [PATCH] Update solution to use provider aware embedding logic if not
 using openai as the LLM provider

---
 requirements.txt                     |  1 +
 tradingagents/agents/utils/memory.py | 29 ++++++++++++++++++----------
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index d8de9fe8..a8826d74 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,3 +25,4 @@ questionary
 langchain_anthropic
 langchain-google-genai
 python-dotenv
+sentence-transformers
diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py
index 69b8ab8c..49aff9cd 100644
--- a/tradingagents/agents/utils/memory.py
+++ b/tradingagents/agents/utils/memory.py
@@ -1,25 +1,34 @@
 import chromadb
 from chromadb.config import Settings
 from openai import OpenAI
+import os
 
 
 class FinancialSituationMemory:
     def __init__(self, name, config):
-        if config["backend_url"] == "http://localhost:11434/v1":
-            self.embedding = "nomic-embed-text"
-        else:
+        self.config = config
+        self.provider = config.get("llm_provider", "openai").lower()
+        if self.provider == "openai":
             self.embedding = "text-embedding-3-small"
-        self.client = OpenAI(base_url=config["backend_url"])
+            self.client = OpenAI(base_url=config["backend_url"])
+            self.embedding_model = None
+        else:
+            # Use a local embedding model for non-OpenAI providers
+            from sentence_transformers import SentenceTransformer
+            self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+            self.client = None
         self.chroma_client = chromadb.Client(Settings(allow_reset=True))
         self.situation_collection = self.chroma_client.create_collection(name=name)
 
     def get_embedding(self, text):
-        """Get OpenAI embedding for a text"""
-
-        response = self.client.embeddings.create(
-            model=self.embedding, input=text
-        )
-        return response.data[0].embedding
+        if self.provider == "openai":
+            response = self.client.embeddings.create(
+                model=self.embedding, input=text
+            )
+            return response.data[0].embedding
+        else:
+            # Use local embedding model
+            return self.embedding_model.encode(text).tolist()
 
     def add_situations(self, situations_and_advice):
         """Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""