Merge 9dffc288ac into a438acdbbd
This commit is contained in:
commit
f46d5a9c61
|
|
@ -1,22 +1,43 @@
|
|||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import uuid
|
||||
import time
|
||||
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name, config):
|
||||
def __init__(self, name, config, session_id=None):
|
||||
if config["backend_url"] == "http://localhost:11434/v1":
|
||||
self.embedding = "nomic-embed-text"
|
||||
else:
|
||||
self.embedding = "text-embedding-3-small"
|
||||
self.client = OpenAI(base_url=config["backend_url"])
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
self.openai_client = OpenAI(base_url=config["backend_url"])
|
||||
|
||||
# Generate session ID if not provided
|
||||
if session_id is None:
|
||||
session_id = str(uuid.uuid4())
|
||||
self.session_id = session_id
|
||||
self.collection_name = f"{name}_{session_id}"
|
||||
|
||||
# Initialize ChromaDB client
|
||||
chroma_path = config.get("chroma_db_path", "./chroma_db")
|
||||
settings = Settings(allow_reset=True)
|
||||
self.chroma_client = chromadb.PersistentClient(path=chroma_path, settings=settings)
|
||||
|
||||
# Get or create collection to avoid conflicts
|
||||
self.situation_collection = self._get_or_create_collection()
|
||||
|
||||
def _get_or_create_collection(self):
|
||||
"""Get existing collection or create new one to avoid conflicts"""
|
||||
try:
|
||||
return self.chroma_client.get_collection(name=self.collection_name)
|
||||
except Exception:
|
||||
return self.chroma_client.create_collection(name=self.collection_name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get OpenAI embedding for a text"""
|
||||
|
||||
response = self.client.embeddings.create(
|
||||
response = self.openai_client.embeddings.create(
|
||||
model=self.embedding, input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
|
@ -66,6 +87,16 @@ class FinancialSituationMemory:
|
|||
|
||||
return matched_results
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Clean up the collection (optional - for resource management).
|
||||
This method can be called to remove the collection when no longer needed.
|
||||
"""
|
||||
try:
|
||||
self.chroma_client.delete_collection(self.collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
import uuid
|
||||
from datetime import date
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ class TradingAgentsGraph:
|
|||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
debug=False,
|
||||
config: Dict[str, Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
"""Initialize the trading agents graph and components.
|
||||
|
||||
|
|
@ -44,9 +46,13 @@ class TradingAgentsGraph:
|
|||
selected_analysts: List of analyst types to include
|
||||
debug: Whether to run in debug mode
|
||||
config: Configuration dictionary. If None, uses default config
|
||||
session_id: Optional unique session ID. If None, generates a new one
|
||||
"""
|
||||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
|
||||
# Generate unique session ID
|
||||
self.session_id = session_id or str(uuid.uuid4())
|
||||
|
||||
# Update the interface's config
|
||||
set_config(self.config)
|
||||
|
|
@ -72,12 +78,12 @@ class TradingAgentsGraph:
|
|||
|
||||
self.toolkit = Toolkit(config=self.config)
|
||||
|
||||
# Initialize memories
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
||||
# Initialize memories with coordinated session ID
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config, self.session_id)
|
||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config, self.session_id)
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config, self.session_id)
|
||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config, self.session_id)
|
||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config, self.session_id)
|
||||
|
||||
# Create tool nodes
|
||||
self.tool_nodes = self._create_tool_nodes()
|
||||
|
|
@ -192,6 +198,7 @@ class TradingAgentsGraph:
|
|||
def _log_state(self, trade_date, final_state):
|
||||
"""Log the final state to a JSON file."""
|
||||
self.log_states_dict[str(trade_date)] = {
|
||||
"session_id": self.session_id,
|
||||
"company_of_interest": final_state["company_of_interest"],
|
||||
"trade_date": final_state["trade_date"],
|
||||
"market_report": final_state["market_report"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue