TradingAgents/tradingagents/agents/utils/memory.py

215 lines
8.3 KiB
Python

import chromadb
from chromadb.config import Settings
from openai import OpenAI
import numpy as np
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
class FinancialSituationMemory:
    """Embedding-backed store of (financial situation, recommendation) pairs.

    Situations are embedded through an OpenAI-compatible endpoint and kept in
    a ChromaDB collection. ``get_memories`` returns the stored recommendations
    whose situations are nearest to a new situation.
    """

    # Backend URL that selects the "nomic-embed-text" model
    # (presumably a local Ollama server -- confirm against deployment config).
    _LOCAL_BACKEND_URL = "http://localhost:11434/v1"
    # text-embedding-3-small has a max context length of 8192 tokens.
    # Conservative estimate: ~8000 tokens * 3 chars/token.
    _MAX_EMBED_CHARS = 24000
    # Safety buffer subtracted from the budget when splitting long texts.
    _CHUNK_BUFFER_CHARS = 1000
    # Overlap between consecutive chunks so context is not cut mid-thought.
    _CHUNK_OVERLAP_CHARS = 500

    def __init__(self, name, config, symbol=None, persistent_dir=None):
        """Create the embedding client and open (or create) the collection.

        Args:
            name: Base name of the ChromaDB collection.
            config: Mapping that must contain ``"backend_url"``, the base URL
                handed to the OpenAI client.
            symbol: Optional suffix; when truthy the collection is named
                ``f"{name}_{symbol}"``.
            persistent_dir: Directory for on-disk storage. When falsy an
                in-memory ChromaDB client is used instead.

        Raises:
            KeyError: If ``config`` has no ``"backend_url"`` entry.
        """
        if config["backend_url"] == self._LOCAL_BACKEND_URL:
            self.embedding = "nomic-embed-text"
        else:
            self.embedding = "text-embedding-3-small"

        # Get API key from config; fall back to None when the config module
        # cannot be imported (e.g. when this file is run as a script).
        try:
            from ...dataflows.config import get_openai_api_key
            api_key = get_openai_api_key()
        except ImportError:
            api_key = None
        self.client = OpenAI(base_url=config["backend_url"], api_key=api_key)

        if persistent_dir:
            # Use PersistentClient for disk storage.
            os.makedirs(persistent_dir, exist_ok=True)
            chroma_settings = Settings(
                anonymized_telemetry=False,
                allow_reset=True,
                is_persistent=True,
            )
            try:
                self.chroma_client = chromadb.PersistentClient(
                    path=persistent_dir,
                    settings=chroma_settings,
                )
            except Exception:
                # Deliberate best-effort fallback: retry without settings if
                # there are compatibility issues with the installed version.
                self.chroma_client = chromadb.PersistentClient(path=persistent_dir)
        else:
            # Use in-memory client for backward compatibility.
            self.chroma_client = chromadb.Client(Settings(allow_reset=True))

        collection_name = f"{name}_{symbol}" if symbol else name
        collection_name = self._sanitize_collection_name(collection_name)

        # Atomic get-or-create: avoids a bare ``except`` that would hide
        # genuine failures (I/O errors, invalid names) behind a second call.
        self.situation_collection = self.chroma_client.get_or_create_collection(
            name=collection_name
        )

    @staticmethod
    def _sanitize_collection_name(raw_name):
        """Replace every character outside ``[A-Za-z0-9_-]`` with ``_``.

        Spaces and dots map to underscores exactly as before, so names that
        were already accepted resolve to the same collection; other symbols
        (``/``, ``^``, ...) no longer leak into the collection name.
        """
        return "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "_-" else "_"
            for ch in raw_name
        )

    def get_embedding(self, text):
        """Embed ``text``, splitting it into chunks when it is too long.

        Args:
            text: The text to embed.

        Returns:
            list: One embedding per chunk. Short texts yield a list holding a
            single embedding.

        Raises:
            ValueError: If the text was split and no chunk could be embedded.
        """
        max_chars = self._MAX_EMBED_CHARS
        if len(text) <= max_chars:
            # Text is short enough, get embedding directly.
            response = self.client.embeddings.create(
                model=self.embedding, input=text
            )
            return [response.data[0].embedding]

        print(f"Text length {len(text)} exceeds limit, splitting into chunks for embedding")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=max_chars - self._CHUNK_BUFFER_CHARS,  # Leave some buffer
            chunk_overlap=self._CHUNK_OVERLAP_CHARS,  # Overlap to preserve context
            length_function=len,
            separators=["\n\n", "\n", ". ", " ", ""],  # Prefer natural boundaries
        )
        chunks = text_splitter.split_text(text)
        print(f"Split text into {len(chunks)} chunks for embedding")

        chunk_embeddings = []
        for i, chunk in enumerate(chunks):
            # Best effort per chunk: a failed chunk is reported and skipped so
            # the remaining chunks can still represent the text.
            try:
                response = self.client.embeddings.create(
                    model=self.embedding, input=chunk
                )
                chunk_embeddings.append(response.data[0].embedding)
            except Exception as e:
                print(f"Failed to get embedding for chunk {i}: {e}")
                continue

        if not chunk_embeddings:
            raise ValueError("Failed to get embeddings for any chunks")
        return chunk_embeddings

    def add_situations(self, situations_and_advice):
        """Store financial situations together with their recommendations.

        Args:
            situations_and_advice: Iterable of ``(situation, recommendation)``
                tuples. A situation that is embedded as several chunks is
                stored once per chunk, each entry carrying the full situation
                text and the same recommendation.
        """
        documents = []
        metadatas = []
        ids = []
        embeddings = []

        # IDs continue from the current number of stored entries.
        next_id = self.situation_collection.count()

        for situation, recommendation in situations_and_advice:
            for embedding in self.get_embedding(situation):
                documents.append(situation)  # Store full situation for each chunk
                metadatas.append({"recommendation": recommendation})
                ids.append(str(next_id))
                embeddings.append(embedding)
                next_id += 1

        if not ids:
            # Nothing to store; skip the call instead of sending empty lists.
            return

        self.situation_collection.add(
            documents=documents,
            metadatas=metadatas,
            embeddings=embeddings,
            ids=ids,
        )

    def get_memories(self, current_situation, n_matches=1):
        """Find the stored recommendations closest to ``current_situation``.

        Args:
            current_situation: Text describing the situation to match.
            n_matches: Maximum number of matches to return.

        Returns:
            list: Dicts with the keys ``"matched_situation"``,
            ``"recommendation"`` and ``"similarity_score"``. Empty when the
            collection holds no entries or ``n_matches`` is below 1.
        """
        # Never request more results than are stored, and skip the embedding
        # request entirely when there is nothing to match against.
        n_results = min(n_matches, self.situation_collection.count())
        if n_results < 1:
            return []

        query_embeddings = self.get_embedding(current_situation)
        if len(query_embeddings) > 1:
            # Long queries come back as several chunks; average them into one
            # query vector.
            query_embedding = np.mean(query_embeddings, axis=0).tolist()
        else:
            query_embedding = query_embeddings[0]

        results = self.situation_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=["metadatas", "documents", "distances"],
        )

        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distances = results["distances"][0]

        # NOTE(review): ``1 - distance`` is a cosine similarity only if the
        # collection uses the cosine space; Chroma's default is believed to be
        # squared L2 -- confirm the intended metric.
        return [
            {
                "matched_situation": document,
                "recommendation": metadata["recommendation"],
                "similarity_score": 1 - distance,
            }
            for document, metadata, distance in zip(documents, metadatas, distances)
        ]
if __name__ == "__main__":
    # Example usage. The constructor requires a collection name and a config
    # mapping with "backend_url"; calling it with no arguments raises TypeError.
    matcher = FinancialSituationMemory(
        "example_memory",
        {"backend_url": "https://api.openai.com/v1"},
    )

    # Example data: (situation, recommendation) pairs.
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations.
    matcher.add_situations(example_data)

    # Example query (same text as a triple-quoted block with leading and
    # trailing newlines).
    current_situation = (
        "\n"
        "Market showing increased volatility in tech sector, with institutional investors\n"
        "reducing positions and rising interest rates affecting growth stock valuations\n"
    )

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)
        for i, rec in enumerate(recommendations, 1):
            print(f"\nMatch {i}:")
            print(f"Similarity Score: {rec['similarity_score']:.2f}")
            print(f"Matched Situation: {rec['matched_situation']}")
            print(f"Recommendation: {rec['recommendation']}")
    except Exception as e:
        print(f"Error during recommendation: {str(e)}")