From 9dffc288ac3d5e2474c5958ed7e75c324ada315c Mon Sep 17 00:00:00 2001 From: samchenku <144151219+samchenku@users.noreply.github.com> Date: Mon, 14 Jul 2025 17:40:45 -0500 Subject: [PATCH] Memory by session --- tradingagents/agents/utils/memory.py | 41 ++++++++++++++++++++++++---- tradingagents/graph/trading_graph.py | 19 +++++++++---- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 69b8ab8c..27a50e79 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,22 +1,43 @@ import chromadb from chromadb.config import Settings from openai import OpenAI +import uuid +import time class FinancialSituationMemory: - def __init__(self, name, config): + def __init__(self, name, config, session_id=None): if config["backend_url"] == "http://localhost:11434/v1": self.embedding = "nomic-embed-text" else: self.embedding = "text-embedding-3-small" - self.client = OpenAI(base_url=config["backend_url"]) - self.chroma_client = chromadb.Client(Settings(allow_reset=True)) - self.situation_collection = self.chroma_client.create_collection(name=name) + self.openai_client = OpenAI(base_url=config["backend_url"]) + + # Generate session ID if not provided + if session_id is None: + session_id = str(uuid.uuid4()) + self.session_id = session_id + self.collection_name = f"{name}_{session_id}" + + # Initialize ChromaDB client + chroma_path = config.get("chroma_db_path", "./chroma_db") + settings = Settings(allow_reset=True) + self.chroma_client = chromadb.PersistentClient(path=chroma_path, settings=settings) + + # Get or create collection to avoid conflicts + self.situation_collection = self._get_or_create_collection() + + def _get_or_create_collection(self): + """Get existing collection or create new one to avoid conflicts""" + try: + return self.chroma_client.get_collection(name=self.collection_name) + except Exception: + return self.chroma_client.create_collection(name=self.collection_name) def get_embedding(self, text): """Get OpenAI embedding for a text""" - response = self.client.embeddings.create( + response = self.openai_client.embeddings.create( model=self.embedding, input=text ) return response.data[0].embedding @@ -66,6 +87,16 @@ class FinancialSituationMemory: return matched_results + def cleanup(self): + """ + Clean up the collection (optional - for resource management). + This method can be called to remove the collection when no longer needed. + """ + try: + self.chroma_client.delete_collection(self.collection_name) + except Exception: + pass + if __name__ == "__main__": # Example usage diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 80a29e53..812110fd 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -3,6 +3,7 @@ import os from pathlib import Path import json +import uuid from datetime import date from typing import Dict, Any, Tuple, List, Optional @@ -37,6 +38,7 @@ class TradingAgentsGraph: selected_analysts=["market", "social", "news", "fundamentals"], debug=False, config: Dict[str, Any] = None, + session_id: Optional[str] = None, ): """Initialize the trading agents graph and components. @@ -44,9 +46,13 @@ class TradingAgentsGraph: selected_analysts: List of analyst types to include debug: Whether to run in debug mode config: Configuration dictionary. If None, uses default config + session_id: Optional unique session ID. If None, generates a new one """ self.debug = debug self.config = config or DEFAULT_CONFIG + + # Generate unique session ID + self.session_id = session_id or str(uuid.uuid4()) # Update the interface's config set_config(self.config) @@ -72,12 +78,12 @@ class TradingAgentsGraph: self.toolkit = Toolkit(config=self.config) - # Initialize memories - self.bull_memory = FinancialSituationMemory("bull_memory", self.config) - self.bear_memory = FinancialSituationMemory("bear_memory", self.config) - self.trader_memory = FinancialSituationMemory("trader_memory", self.config) - self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) - self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) + # Initialize memories with coordinated session ID + self.bull_memory = FinancialSituationMemory("bull_memory", self.config, self.session_id) + self.bear_memory = FinancialSituationMemory("bear_memory", self.config, self.session_id) + self.trader_memory = FinancialSituationMemory("trader_memory", self.config, self.session_id) + self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config, self.session_id) + self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config, self.session_id) # Create tool nodes self.tool_nodes = self._create_tool_nodes() @@ -192,6 +198,7 @@ class TradingAgentsGraph: def _log_state(self, trade_date, final_state): """Log the final state to a JSON file.""" self.log_states_dict[str(trade_date)] = { + "session_id": self.session_id, "company_of_interest": final_state["company_of_interest"], "trade_date": final_state["trade_date"], "market_report": final_state["market_report"],