Merge 9dffc288ac into a438acdbbd
This commit is contained in:
commit
f46d5a9c61
|
|
@ -1,22 +1,43 @@
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
class FinancialSituationMemory:
|
class FinancialSituationMemory:
|
||||||
def __init__(self, name, config):
|
def __init__(self, name, config, session_id=None):
|
||||||
if config["backend_url"] == "http://localhost:11434/v1":
|
if config["backend_url"] == "http://localhost:11434/v1":
|
||||||
self.embedding = "nomic-embed-text"
|
self.embedding = "nomic-embed-text"
|
||||||
else:
|
else:
|
||||||
self.embedding = "text-embedding-3-small"
|
self.embedding = "text-embedding-3-small"
|
||||||
self.client = OpenAI(base_url=config["backend_url"])
|
self.openai_client = OpenAI(base_url=config["backend_url"])
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
# Generate session ID if not provided
|
||||||
|
if session_id is None:
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
self.session_id = session_id
|
||||||
|
self.collection_name = f"{name}_{session_id}"
|
||||||
|
|
||||||
|
# Initialize ChromaDB client
|
||||||
|
chroma_path = config.get("chroma_db_path", "./chroma_db")
|
||||||
|
settings = Settings(allow_reset=True)
|
||||||
|
self.chroma_client = chromadb.PersistentClient(path=chroma_path, settings=settings)
|
||||||
|
|
||||||
|
# Get or create collection to avoid conflicts
|
||||||
|
self.situation_collection = self._get_or_create_collection()
|
||||||
|
|
||||||
|
def _get_or_create_collection(self):
|
||||||
|
"""Get existing collection or create new one to avoid conflicts"""
|
||||||
|
try:
|
||||||
|
return self.chroma_client.get_collection(name=self.collection_name)
|
||||||
|
except Exception:
|
||||||
|
return self.chroma_client.create_collection(name=self.collection_name)
|
||||||
|
|
||||||
def get_embedding(self, text):
|
def get_embedding(self, text):
|
||||||
"""Get OpenAI embedding for a text"""
|
"""Get OpenAI embedding for a text"""
|
||||||
|
|
||||||
response = self.client.embeddings.create(
|
response = self.openai_client.embeddings.create(
|
||||||
model=self.embedding, input=text
|
model=self.embedding, input=text
|
||||||
)
|
)
|
||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
|
|
@ -66,6 +87,16 @@ class FinancialSituationMemory:
|
||||||
|
|
||||||
return matched_results
|
return matched_results
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""
|
||||||
|
Clean up the collection (optional - for resource management).
|
||||||
|
This method can be called to remove the collection when no longer needed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.chroma_client.delete_collection(self.collection_name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Example usage
|
# Example usage
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
|
import uuid
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Dict, Any, Tuple, List, Optional
|
from typing import Dict, Any, Tuple, List, Optional
|
||||||
|
|
||||||
|
|
@ -37,6 +38,7 @@ class TradingAgentsGraph:
|
||||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
debug=False,
|
debug=False,
|
||||||
config: Dict[str, Any] = None,
|
config: Dict[str, Any] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the trading agents graph and components.
|
"""Initialize the trading agents graph and components.
|
||||||
|
|
||||||
|
|
@ -44,9 +46,13 @@ class TradingAgentsGraph:
|
||||||
selected_analysts: List of analyst types to include
|
selected_analysts: List of analyst types to include
|
||||||
debug: Whether to run in debug mode
|
debug: Whether to run in debug mode
|
||||||
config: Configuration dictionary. If None, uses default config
|
config: Configuration dictionary. If None, uses default config
|
||||||
|
session_id: Optional unique session ID. If None, generates a new one
|
||||||
"""
|
"""
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config or DEFAULT_CONFIG
|
self.config = config or DEFAULT_CONFIG
|
||||||
|
|
||||||
|
# Generate unique session ID
|
||||||
|
self.session_id = session_id or str(uuid.uuid4())
|
||||||
|
|
||||||
# Update the interface's config
|
# Update the interface's config
|
||||||
set_config(self.config)
|
set_config(self.config)
|
||||||
|
|
@ -72,12 +78,12 @@ class TradingAgentsGraph:
|
||||||
|
|
||||||
self.toolkit = Toolkit(config=self.config)
|
self.toolkit = Toolkit(config=self.config)
|
||||||
|
|
||||||
# Initialize memories
|
# Initialize memories with coordinated session ID
|
||||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
self.bull_memory = FinancialSituationMemory("bull_memory", self.config, self.session_id)
|
||||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
self.bear_memory = FinancialSituationMemory("bear_memory", self.config, self.session_id)
|
||||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
self.trader_memory = FinancialSituationMemory("trader_memory", self.config, self.session_id)
|
||||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config, self.session_id)
|
||||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config, self.session_id)
|
||||||
|
|
||||||
# Create tool nodes
|
# Create tool nodes
|
||||||
self.tool_nodes = self._create_tool_nodes()
|
self.tool_nodes = self._create_tool_nodes()
|
||||||
|
|
@ -192,6 +198,7 @@ class TradingAgentsGraph:
|
||||||
def _log_state(self, trade_date, final_state):
|
def _log_state(self, trade_date, final_state):
|
||||||
"""Log the final state to a JSON file."""
|
"""Log the final state to a JSON file."""
|
||||||
self.log_states_dict[str(trade_date)] = {
|
self.log_states_dict[str(trade_date)] = {
|
||||||
|
"session_id": self.session_id,
|
||||||
"company_of_interest": final_state["company_of_interest"],
|
"company_of_interest": final_state["company_of_interest"],
|
||||||
"trade_date": final_state["trade_date"],
|
"trade_date": final_state["trade_date"],
|
||||||
"market_report": final_state["market_report"],
|
"market_report": final_state["market_report"],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue