This commit is contained in:
Sam Chen 2025-07-14 22:21:58 -05:00 committed by GitHub
commit f46d5a9c61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 11 deletions

View File

@ -1,22 +1,43 @@
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
from openai import OpenAI from openai import OpenAI
import uuid
import time
class FinancialSituationMemory: class FinancialSituationMemory:
def __init__(self, name, config): def __init__(self, name, config, session_id=None):
if config["backend_url"] == "http://localhost:11434/v1": if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text" self.embedding = "nomic-embed-text"
else: else:
self.embedding = "text-embedding-3-small" self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"]) self.openai_client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name) # Generate session ID if not provided
if session_id is None:
session_id = str(uuid.uuid4())
self.session_id = session_id
self.collection_name = f"{name}_{session_id}"
# Initialize ChromaDB client
chroma_path = config.get("chroma_db_path", "./chroma_db")
settings = Settings(allow_reset=True)
self.chroma_client = chromadb.PersistentClient(path=chroma_path, settings=settings)
# Get or create collection to avoid conflicts
self.situation_collection = self._get_or_create_collection()
def _get_or_create_collection(self):
"""Get existing collection or create new one to avoid conflicts"""
try:
return self.chroma_client.get_collection(name=self.collection_name)
except Exception:
return self.chroma_client.create_collection(name=self.collection_name)
def get_embedding(self, text): def get_embedding(self, text):
"""Get OpenAI embedding for a text""" """Get OpenAI embedding for a text"""
response = self.client.embeddings.create( response = self.openai_client.embeddings.create(
model=self.embedding, input=text model=self.embedding, input=text
) )
return response.data[0].embedding return response.data[0].embedding
@ -66,6 +87,16 @@ class FinancialSituationMemory:
return matched_results return matched_results
def cleanup(self):
"""
Clean up the collection (optional - for resource management).
This method can be called to remove the collection when no longer needed.
"""
try:
self.chroma_client.delete_collection(self.collection_name)
except Exception:
pass
if __name__ == "__main__": if __name__ == "__main__":
# Example usage # Example usage

View File

@ -3,6 +3,7 @@
import os import os
from pathlib import Path from pathlib import Path
import json import json
import uuid
from datetime import date from datetime import date
from typing import Dict, Any, Tuple, List, Optional from typing import Dict, Any, Tuple, List, Optional
@ -37,6 +38,7 @@ class TradingAgentsGraph:
selected_analysts=["market", "social", "news", "fundamentals"], selected_analysts=["market", "social", "news", "fundamentals"],
debug=False, debug=False,
config: Dict[str, Any] = None, config: Dict[str, Any] = None,
session_id: Optional[str] = None,
): ):
"""Initialize the trading agents graph and components. """Initialize the trading agents graph and components.
@ -44,9 +46,13 @@ class TradingAgentsGraph:
selected_analysts: List of analyst types to include selected_analysts: List of analyst types to include
debug: Whether to run in debug mode debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config config: Configuration dictionary. If None, uses default config
session_id: Optional unique session ID. If None, generates a new one
""" """
self.debug = debug self.debug = debug
self.config = config or DEFAULT_CONFIG self.config = config or DEFAULT_CONFIG
# Generate unique session ID
self.session_id = session_id or str(uuid.uuid4())
# Update the interface's config # Update the interface's config
set_config(self.config) set_config(self.config)
@ -72,12 +78,12 @@ class TradingAgentsGraph:
self.toolkit = Toolkit(config=self.config) self.toolkit = Toolkit(config=self.config)
# Initialize memories # Initialize memories with coordinated session ID
self.bull_memory = FinancialSituationMemory("bull_memory", self.config) self.bull_memory = FinancialSituationMemory("bull_memory", self.config, self.session_id)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config) self.bear_memory = FinancialSituationMemory("bear_memory", self.config, self.session_id)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config) self.trader_memory = FinancialSituationMemory("trader_memory", self.config, self.session_id)
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config, self.session_id)
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config, self.session_id)
# Create tool nodes # Create tool nodes
self.tool_nodes = self._create_tool_nodes() self.tool_nodes = self._create_tool_nodes()
@ -192,6 +198,7 @@ class TradingAgentsGraph:
def _log_state(self, trade_date, final_state): def _log_state(self, trade_date, final_state):
"""Log the final state to a JSON file.""" """Log the final state to a JSON file."""
self.log_states_dict[str(trade_date)] = { self.log_states_dict[str(trade_date)] = {
"session_id": self.session_id,
"company_of_interest": final_state["company_of_interest"], "company_of_interest": final_state["company_of_interest"],
"trade_date": final_state["trade_date"], "trade_date": final_state["trade_date"],
"market_report": final_state["market_report"], "market_report": final_state["market_report"],