import os
from typing import Any, Dict, List, Optional, Tuple

import chromadb
from openai import OpenAI

from tradingagents.utils.logger import get_logger

logger = get_logger(__name__)


class FinancialSituationMemory:
    """Persistent vector memory of financial situations and the advice given.

    Situations are embedded via an OpenAI-compatible embeddings API (OpenAI
    proper, or a local Ollama server) and stored with their recommendations
    and outcome metadata in an on-disk ChromaDB collection, so past analyses
    can be recalled by semantic similarity.
    """

    def __init__(self, name, config):
        """Open (or create) the named memory collection.

        Args:
            name: ChromaDB collection name for this memory bank.
            config: Project config object exposing ``get(key, default)`` and
                ``validate_key(key, provider)``.
        """
        # Select the embedding backend. When the LLM backend is a local
        # Ollama server, use its embedding endpoint/model; otherwise always
        # embed through OpenAI regardless of the configured LLM provider.
        if config.get("backend_url") == "http://localhost:11434/v1":
            self.embedding_backend = "http://localhost:11434/v1"
            self.embedding = "nomic-embed-text"
        else:
            self.embedding_backend = "https://api.openai.com/v1"
            self.embedding = "text-embedding-3-small"

        # Fix: the original never passed base_url, so the Ollama branch
        # silently sent requests to api.openai.com with an Ollama model name.
        # NOTE(review): the OpenAI key is still validated even on the Ollama
        # path (Ollama accepts any key) — confirm whether that is desired.
        self.client = OpenAI(
            api_key=config.validate_key("openai_api_key", "OpenAI"),
            base_url=self.embedding_backend,
        )

        # Persist the vector store under <project_dir>/memory_db.
        persist_directory = os.path.join(config.get("project_dir", "."), "memory_db")
        os.makedirs(persist_directory, exist_ok=True)
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)

        # Idempotent fetch-or-create (replaces the old try/except dance).
        self.situation_collection = self.chroma_client.get_or_create_collection(name=name)

    def get_embedding(self, text):
        """Return the embedding vector for ``text`` using the configured model."""
        response = self.client.embeddings.create(model=self.embedding, input=text)
        return response.data[0].embedding

    def _batch_add(
        self,
        documents: List[str],
        metadatas: List[Dict[str, Any]],
        embeddings: List[List[float]],
        ids: Optional[List[str]] = None,
    ):
        """Batch-add documents to the ChromaDB collection.

        When ``ids`` is omitted, sequential string ids are generated starting
        at the current collection size.

        NOTE(review): count-based ids can collide if entries are ever deleted
        from the collection — confirm deletions never happen, or switch to
        UUIDs.
        """
        if not documents:
            return
        if ids is None:
            offset = self.situation_collection.count()
            ids = [str(offset + i) for i in range(len(documents))]
        self.situation_collection.add(
            documents=documents,
            metadatas=metadatas,
            embeddings=embeddings,
            ids=ids,
        )

    def add_situations(self, situations_and_advice):
        """Add financial situations and their corresponding advice.

        Args:
            situations_and_advice: List of ``(situation, recommendation)``
                tuples; each situation is embedded and stored with its
                recommendation as metadata.
        """
        situations = []
        metadatas = []
        embeddings = []
        for situation, recommendation in situations_and_advice:
            situations.append(situation)
            metadatas.append({"recommendation": recommendation})
            embeddings.append(self.get_embedding(situation))
        self._batch_add(situations, metadatas, embeddings)

    def add_situations_with_metadata(
        self, situations_and_outcomes: List[Tuple[str, str, Dict[str, Any]]]
    ):
        """
        Add financial situations with enhanced metadata for the learning system.

        Args:
            situations_and_outcomes: List of tuples
                (situation_text, recommendation, metadata) where metadata
                contains:
                - ticker: Stock symbol
                - analysis_date: Date of analysis (YYYY-MM-DD)
                - days_before_move: How many days before the major move (7 or 30)
                - move_pct: Percentage move that occurred
                - move_direction: "up" or "down"
                - agent_recommendation: What the agent recommended
                - was_correct: Boolean, whether recommendation matched outcome
                - structured_signals: Dict of signal features (optional)
                    - unusual_volume: bool
                    - analyst_sentiment: str (bullish/bearish/neutral)
                    - news_sentiment: str (positive/negative/neutral)
                    - short_interest: str (high/medium/low)
                    - insider_activity: str (buying/selling/none)
                    - etc.
        """
        situations = []
        metadatas = []
        embeddings = []
        for situation, recommendation, metadata in situations_and_outcomes:
            situations.append(situation)
            embeddings.append(self.get_embedding(situation))
            # Merge the recommendation into the caller-supplied metadata,
            # then coerce everything into ChromaDB-compatible scalar types.
            full_metadata = {"recommendation": recommendation}
            full_metadata.update(metadata)
            metadatas.append(self._sanitize_metadata(full_metadata))
        self._batch_add(situations, metadatas, embeddings)

    def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """
        Sanitize metadata for ChromaDB compatibility.

        ChromaDB requires metadata values to be str, int, float, or bool.
        Nested dicts are flattened with dot notation (``parent.child``);
        ``None`` becomes the string ``"none"``; any other type is stringified
        (the original silently dropped non-scalar nested values).
        """
        sanitized: Dict[str, Any] = {}
        for key, value in metadata.items():
            if isinstance(value, dict):
                # Flatten one level of nesting with dot notation.
                for nested_key, nested_value in value.items():
                    flat_key = f"{key}.{nested_key}"
                    if nested_value is None:
                        sanitized[flat_key] = "none"
                    elif isinstance(nested_value, (str, int, float, bool)):
                        sanitized[flat_key] = nested_value
                    else:
                        # Fix: previously dropped; now stringified like
                        # top-level non-scalars.
                        sanitized[flat_key] = str(nested_value)
            elif value is None:
                sanitized[key] = "none"
            elif isinstance(value, (str, int, float, bool)):
                sanitized[key] = value
            else:
                # Convert other types (lists, dates, ...) to string.
                sanitized[key] = str(value)
        return sanitized

    def get_memories(self, current_situation, n_matches=1):
        """Find matching recommendations by embedding similarity.

        Returns a list of dicts with ``matched_situation``,
        ``recommendation`` and ``similarity_score`` (1 - distance).
        """
        query_embedding = self.get_embedding(current_situation)
        results = self.situation_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_matches,
            include=["metadatas", "documents", "distances"],
        )
        matched_results = []
        for i in range(len(results["documents"][0])):
            matched_results.append(
                {
                    "matched_situation": results["documents"][0][i],
                    "recommendation": results["metadatas"][0][i]["recommendation"],
                    "similarity_score": 1 - results["distances"][0][i],
                }
            )
        return matched_results

    def get_memories_hybrid(
        self,
        current_situation: str,
        signal_filters: Optional[Dict[str, Any]] = None,
        n_matches: int = 3,
        min_similarity: float = 0.5,
    ) -> List[Dict[str, Any]]:
        """
        Hybrid search: filter by structured signals, then rank by embedding
        similarity.

        Args:
            current_situation: Text description of current market situation.
            signal_filters: Dict of structured signals to filter by
                (e.g., {"unusual_volume": True}). Supports exact matches and
                dot notation for nested fields, e.g.
                {"structured_signals.unusual_volume": True}.
            n_matches: Number of results to return.
            min_similarity: Minimum similarity score (0-1) to include.

        Returns:
            List of dicts with keys:
            - matched_situation: Historical situation text
            - recommendation: What was recommended
            - similarity_score: Embedding similarity (0-1)
            - metadata: Full metadata including outcome, signals, etc.
            (plus convenience fields ticker / move_pct / move_direction /
            was_correct / days_before_move extracted from the metadata).
        """
        query_embedding = self.get_embedding(current_situation)

        # Build the where clause. Fix: ChromaDB rejects a plain dict with
        # more than one key — multiple conditions must be combined via $and.
        where_clause = None
        if signal_filters:
            if len(signal_filters) == 1:
                where_clause = dict(signal_filters)
            else:
                where_clause = {"$and": [{k: v} for k, v in signal_filters.items()]}

        # Over-fetch (3x, capped at 100) so the similarity threshold below
        # still leaves enough candidates.
        query_params = {
            "query_embeddings": [query_embedding],
            "n_results": min(n_matches * 3, 100),
            "include": ["metadatas", "documents", "distances"],
        }
        if where_clause:
            query_params["where"] = where_clause
        results = self.situation_collection.query(**query_params)

        # Apply the similarity floor and flatten the results.
        matched_results = []
        for i in range(len(results["documents"][0])):
            similarity_score = 1 - results["distances"][0][i]
            if similarity_score < min_similarity:
                continue
            metadata = results["metadatas"][0][i]
            matched_results.append(
                {
                    "matched_situation": results["documents"][0][i],
                    "recommendation": metadata.get("recommendation", ""),
                    "similarity_score": similarity_score,
                    "metadata": metadata,
                    # Extract key fields for convenience
                    "ticker": metadata.get("ticker", ""),
                    "move_pct": metadata.get("move_pct", 0),
                    "move_direction": metadata.get("move_direction", ""),
                    "was_correct": metadata.get("was_correct", False),
                    "days_before_move": metadata.get("days_before_move", 0),
                }
            )

        # Chroma returns results ordered by distance, so the head is the best.
        return matched_results[:n_matches]

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the memory bank.

        Returns:
            Dict with keys:
            - total_memories: Total number of stored memories
            - accuracy_rate: % correct among memories that RECORD an outcome
              (fix: previously divided by ALL memories, so plain-advice
              entries without ``was_correct`` dragged accuracy toward zero)
            - avg_move_pct: Average percentage move in stored outcomes
            - signal_distribution: Count of different signal patterns
        """
        total_count = self.situation_collection.count()
        if total_count == 0:
            return {
                "total_memories": 0,
                "accuracy_rate": 0.0,
                "avg_move_pct": 0.0,
                "signal_distribution": {},
            }

        # Fetch every memory's metadata in one call.
        all_results = self.situation_collection.get(include=["metadatas"])
        metadatas = all_results["metadatas"]

        # Accuracy over outcome-bearing memories only.
        outcome_flags = [m["was_correct"] for m in metadatas if "was_correct" in m]
        correct_count = sum(1 for flag in outcome_flags if flag is True)
        accuracy_rate = (correct_count / len(outcome_flags) * 100) if outcome_flags else 0.0

        move_pcts = [m.get("move_pct", 0) for m in metadatas if "move_pct" in m]
        avg_move_pct = sum(move_pcts) / len(move_pcts) if move_pcts else 0

        # Tally value counts for each flattened structured signal.
        signal_distribution: Dict[str, Dict[Any, int]] = {}
        for metadata in metadatas:
            for key, value in metadata.items():
                if key.startswith("structured_signals."):
                    signal_name = key.replace("structured_signals.", "")
                    per_signal = signal_distribution.setdefault(signal_name, {})
                    per_signal[value] = per_signal.get(value, 0) + 1

        return {
            "total_memories": total_count,
            "accuracy_rate": accuracy_rate,
            "avg_move_pct": avg_move_pct,
            "signal_distribution": signal_distribution,
        }


if __name__ == "__main__":
    # Fix: the original called FinancialSituationMemory() with no arguments,
    # which crashes (__init__ requires name and config). Provide a minimal
    # config shim so the demo runs standalone; real callers pass the
    # project's config object.
    class _DemoConfig(dict):
        def validate_key(self, key, provider):
            value = os.environ.get("OPENAI_API_KEY")
            if not value:
                raise ValueError(f"Missing {provider} API key for '{key}'")
            return value

    matcher = FinancialSituationMemory("example_memory", _DemoConfig())

    # Example data
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations
    matcher.add_situations(example_data)

    # Example query
    current_situation = """
    Market showing increased volatility in tech sector, with institutional investors 
    reducing positions and rising interest rates affecting growth stock valuations
    """

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)
        for i, rec in enumerate(recommendations, 1):
            logger.info(f"Match {i}:")
            logger.info(f"Similarity Score: {rec['similarity_score']:.2f}")
            logger.info(f"Matched Situation: {rec['matched_situation']}")
            logger.info(f"Recommendation: {rec['recommendation']}")
    except Exception as e:
        logger.error(f"Error during recommendation: {str(e)}")