Memory by session

This commit is contained in:
samchenku 2025-07-14 17:40:45 -05:00
parent a438acdbbd
commit 9dffc288ac
2 changed files with 49 additions and 11 deletions

View File

@ -1,22 +1,43 @@
import chromadb
from chromadb.config import Settings
from openai import OpenAI
import uuid
import time
class FinancialSituationMemory:
def __init__(self, name, config):
def __init__(self, name, config, session_id=None):
if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)
self.openai_client = OpenAI(base_url=config["backend_url"])
# Generate session ID if not provided
if session_id is None:
session_id = str(uuid.uuid4())
self.session_id = session_id
self.collection_name = f"{name}_{session_id}"
# Initialize ChromaDB client
chroma_path = config.get("chroma_db_path", "./chroma_db")
settings = Settings(allow_reset=True)
self.chroma_client = chromadb.PersistentClient(path=chroma_path, settings=settings)
# Get or create collection to avoid conflicts
self.situation_collection = self._get_or_create_collection()
def _get_or_create_collection(self):
"""Get existing collection or create new one to avoid conflicts"""
try:
return self.chroma_client.get_collection(name=self.collection_name)
except Exception:
return self.chroma_client.create_collection(name=self.collection_name)
def get_embedding(self, text):
"""Get OpenAI embedding for a text"""
response = self.client.embeddings.create(
response = self.openai_client.embeddings.create(
model=self.embedding, input=text
)
return response.data[0].embedding
@ -66,6 +87,16 @@ class FinancialSituationMemory:
return matched_results
def cleanup(self):
"""
Clean up the collection (optional - for resource management).
This method can be called to remove the collection when no longer needed.
"""
try:
self.chroma_client.delete_collection(self.collection_name)
except Exception:
pass
if __name__ == "__main__":
# Example usage

View File

@ -3,6 +3,7 @@
import os
from pathlib import Path
import json
import uuid
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
@ -37,6 +38,7 @@ class TradingAgentsGraph:
selected_analysts=["market", "social", "news", "fundamentals"],
debug=False,
config: Dict[str, Any] = None,
session_id: Optional[str] = None,
):
"""Initialize the trading agents graph and components.
@ -44,9 +46,13 @@ class TradingAgentsGraph:
selected_analysts: List of analyst types to include
debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config
session_id: Optional unique session ID. If None, generates a new one
"""
self.debug = debug
self.config = config or DEFAULT_CONFIG
# Generate unique session ID
self.session_id = session_id or str(uuid.uuid4())
# Update the interface's config
set_config(self.config)
@ -72,12 +78,12 @@ class TradingAgentsGraph:
self.toolkit = Toolkit(config=self.config)
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
# Initialize memories with coordinated session ID
self.bull_memory = FinancialSituationMemory("bull_memory", self.config, self.session_id)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config, self.session_id)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config, self.session_id)
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config, self.session_id)
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config, self.session_id)
# Create tool nodes
self.tool_nodes = self._create_tool_nodes()
@ -192,6 +198,7 @@ class TradingAgentsGraph:
def _log_state(self, trade_date, final_state):
"""Log the final state to a JSON file."""
self.log_states_dict[str(trade_date)] = {
"session_id": self.session_id,
"company_of_interest": final_state["company_of_interest"],
"trade_date": final_state["trade_date"],
"market_report": final_state["market_report"],