import chromadb
from chromadb.config import Settings
import os


class FinancialSituationMemory:
    """Store past financial situations with advice and retrieve the closest matches.

    Situations are embedded in one of three ways:

    - locally with sentence-transformers;
    - by ChromaDB's built-in default embedding function, used as a fallback
      when sentence-transformers cannot be imported;
    - through an OpenAI-compatible embeddings API (legacy path).

    They are then stored in a ChromaDB collection created on an in-process
    ``chromadb.Client``.
    """

    def __init__(self, name, config):
        """Set up the embedding backend and the ChromaDB collection.

        Args:
            name: Name of the ChromaDB collection. If a collection with this
                name already exists on the client, it is reused instead of
                raising.
            config: Mapping of options. Reads ``use_local_embeddings``
                (default True). When that is false, it also reads
                ``backend_url`` and, optionally, ``llm_provider``.

        Raises:
            ValueError: API-based embeddings were requested with an
                OpenRouter ``backend_url`` but OPENAI_API_KEY is not set.
        """
        # Use local embeddings for all providers - no external API dependency
        self.use_local_embeddings = config.get("use_local_embeddings", True)
        if self.use_local_embeddings:
            self._init_local_embeddings()
        else:
            self._init_api_embeddings(config)

        self.chroma_client = chromadb.Client(Settings(allow_reset=True))
        # get_or_create so that a second instance with the same name does not
        # fail with "collection already exists". The cosine space is used for
        # every backend (including ChromaDB's default embeddings) so that
        # `1 - distance` in get_memories is a cosine similarity rather than a
        # meaningless transform of an L2 distance.
        self.situation_collection = self.chroma_client.get_or_create_collection(
            name=name,
            metadata={"hnsw:space": "cosine"},
        )

    def _init_local_embeddings(self):
        """Load the sentence-transformers model, or fall back to ChromaDB's default."""
        try:
            from sentence_transformers import SentenceTransformer

            # Use a good general-purpose model for financial text
            self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
            self.embedding_type = "local"
            print("✅ Using local embeddings with sentence-transformers")
        except ImportError:
            print("⚠️ sentence-transformers not found. Install with: pip install sentence-transformers")
            print("Falling back to ChromaDB's default embeddings...")
            self.embedding_model = None
            self.embedding_type = "chromadb_default"

    def _init_api_embeddings(self, config):
        """Configure an OpenAI-compatible embeddings client (legacy path, kept for backward compatibility)."""
        from openai import OpenAI

        backend_url = config["backend_url"]
        if backend_url == "http://localhost:11434/v1":
            self.embedding = "nomic-embed-text"
            # The OpenAI client raises when no api_key is given and
            # OPENAI_API_KEY is unset. A local Ollama endpoint does not check
            # the key, so supply a placeholder instead of crashing.
            self.client = OpenAI(
                base_url=backend_url,
                api_key=os.getenv("OPENAI_API_KEY") or "ollama",
            )
        else:
            self.embedding = "text-embedding-3-small"
            if "openrouter.ai" in backend_url:
                openai_api_key = os.getenv("OPENAI_API_KEY")
                if not openai_api_key:
                    raise ValueError("❌ OPENAI_API_KEY required for API-based embeddings with OpenRouter")
                # base_url is deliberately not passed: embeddings go to the
                # OpenAI client's default endpoint, not to OpenRouter.
                self.client = OpenAI(api_key=openai_api_key)
            else:
                api_key = None
                if config.get("llm_provider") == "openai":
                    api_key = os.getenv("OPENAI_API_KEY")
                self.client = OpenAI(base_url=backend_url, api_key=api_key)
        self.embedding_type = "api"

    def get_embedding(self, text):
        """Return the embedding vector for ``text``.

        The local backend's output is converted with ``.tolist()`` when it
        has that method. The API backend returns the first ``data`` entry's
        embedding. Returns None when ``embedding_type`` is
        ``"chromadb_default"``, because ChromaDB embeds documents itself in
        that mode.

        Raises:
            RuntimeError: Wraps any failure from the embedding backend,
                chained to the original exception.
        """
        try:
            if self.embedding_type == "local":
                embedding = self.embedding_model.encode(text, convert_to_tensor=False)
                return embedding.tolist() if hasattr(embedding, "tolist") else embedding
            if self.embedding_type == "chromadb_default":
                # ChromaDB will handle embeddings automatically
                return None
            # API-based embeddings
            response = self.client.embeddings.create(model=self.embedding, input=text)
            if hasattr(response, "data") and len(response.data) > 0:
                return response.data[0].embedding
            raise ValueError(f"Unexpected response format from embeddings API: {type(response)}")
        except Exception as e:
            raise RuntimeError(f"Failed to get embedding for text: {str(e)}") from e

    def add_situations(self, situations_and_advice):
        """Add financial situations and their corresponding advice.

        Args:
            situations_and_advice: Iterable of ``(situation, recommendation)``
                tuples. Ids are assigned sequentially after the current
                collection size.
        """
        pairs = list(situations_and_advice)
        if not pairs:
            # Nothing to store. ChromaDB's add() also rejects an empty id list.
            return

        situations = [situation for situation, _ in pairs]
        metadatas = [{"recommendation": recommendation} for _, recommendation in pairs]
        offset = self.situation_collection.count()
        ids = [str(offset + i) for i in range(len(pairs))]

        if self.embedding_type == "chromadb_default":
            # Let ChromaDB compute embeddings automatically
            self.situation_collection.add(
                documents=situations,
                metadatas=metadatas,
                ids=ids,
            )
        else:
            # Use our own embeddings
            embeddings = [self.get_embedding(situation) for situation in situations]
            self.situation_collection.add(
                documents=situations,
                metadatas=metadatas,
                embeddings=embeddings,
                ids=ids,
            )

    def get_memories(self, current_situation, n_matches=1):
        """Find stored situations closest to ``current_situation``.

        Args:
            current_situation: Text describing the current situation.
            n_matches: Maximum number of matches to return. The request is
                clamped to the number of stored situations.

        Returns:
            A list of dicts with keys ``matched_situation``,
            ``recommendation`` and ``similarity_score`` (``1 - distance``),
            in the order ChromaDB returns them. The list is empty when the
            collection is empty or ``n_matches < 1``.
        """
        stored = self.situation_collection.count()
        if stored == 0 or n_matches < 1:
            return []
        # Never ask ChromaDB for more neighbours than the collection holds.
        n_results = min(n_matches, stored)
        include = ["metadatas", "documents", "distances"]

        if self.embedding_type == "chromadb_default":
            # ChromaDB embeds the query text itself
            results = self.situation_collection.query(
                query_texts=[current_situation],
                n_results=n_results,
                include=include,
            )
        else:
            query_embedding = self.get_embedding(current_situation)
            results = self.situation_collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                include=include,
            )

        # Single query -> index 0 of each per-query result list.
        return [
            {
                "matched_situation": document,
                "recommendation": metadata["recommendation"],
                "similarity_score": 1 - distance,
            }
            for document, metadata, distance in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]


if __name__ == "__main__":
    # Example usage: the constructor requires a collection name and a config.
    matcher = FinancialSituationMemory(
        "financial_situations_example", {"use_local_embeddings": True}
    )

    # Example data
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations
    matcher.add_situations(example_data)

    # Example query
    current_situation = """
    Market showing increased volatility in tech sector, with institutional investors
    reducing positions and rising interest rates affecting growth stock valuations
    """

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)

        for i, rec in enumerate(recommendations, 1):
            print(f"\nMatch {i}:")
            print(f"Similarity Score: {rec['similarity_score']:.2f}")
            print(f"Matched Situation: {rec['matched_situation']}")
            print(f"Recommendation: {rec['recommendation']}")

    except Exception as e:
        print(f"Error during recommendation: {str(e)}")