# TradingAgents/tradingagents/agents/utils/memory.py
import chromadb
from chromadb.config import Settings
from openai import OpenAI
import logging
from typing import List, Dict, Any, Optional, Tuple
logger = logging.getLogger(__name__)
class FinancialSituationMemory:
    """
    Memory system for financial trading agents with support for multiple
    embedding providers.

    Supports:
    - OpenAI embeddings
    - Ollama local embeddings
    - Graceful fallback when embeddings are unavailable (memory is simply
      disabled instead of raising)
    """

    def __init__(self, name: str, config: Dict[str, Any]):
        """
        Initialize the financial situation memory.

        Args:
            name: Name of the memory collection.
            config: Configuration dictionary containing embedding settings
                (keys: ``enable_memory``, ``embedding_provider``,
                ``embedding_model``, ``embedding_backend_url``).
        """
        self.name = name
        self.config = config
        self.enabled = config.get("enable_memory", True)

        # Resolve embedding provider/model before creating any clients.
        self.embedding_provider = config.get("embedding_provider", "openai").lower()
        self.embedding_model = self._get_embedding_model()
        self.embedding_backend_url = config.get(
            "embedding_backend_url", "https://api.openai.com/v1"
        )

        # Initialize OpenAI-compatible client for embeddings (if enabled and
        # supported). Ollama exposes an OpenAI-compatible endpoint, so the
        # same client class is used for both providers.
        self.client = None
        if not self.enabled:
            logger.info("Memory disabled for %s", name)
        elif self.embedding_provider in ("openai", "ollama"):
            try:
                self.client = OpenAI(base_url=self.embedding_backend_url)
                logger.info(
                    "Initialized embedding client for provider: %s",
                    self.embedding_provider,
                )
            except Exception as e:
                logger.warning(
                    "Failed to initialize embedding client: %s. Memory will be disabled.",
                    e,
                )
                self.enabled = False
        elif self.embedding_provider == "none":
            logger.info(
                "Embedding provider set to 'none'. Memory will be disabled for %s.",
                name,
            )
            self.enabled = False
        else:
            logger.warning(
                "Unsupported embedding provider: %s. Memory will be disabled.",
                self.embedding_provider,
            )
            self.enabled = False

        # Initialize ChromaDB collection (only once embeddings are known to work).
        self.chroma_client = None
        self.situation_collection = None
        if self.enabled:
            try:
                self.chroma_client = chromadb.Client(Settings(allow_reset=True))
                # get_or_create avoids a crash when a collection with this
                # name already exists in the in-process Chroma client
                # (create_collection raises on a duplicate name).
                self.situation_collection = self.chroma_client.get_or_create_collection(
                    name=name
                )
                logger.info("Initialized ChromaDB collection: %s", name)
            except Exception as e:
                logger.error(
                    "Failed to initialize ChromaDB collection: %s. Memory will be disabled.",
                    e,
                )
                self.enabled = False

    def _get_embedding_model(self) -> str:
        """
        Get the appropriate embedding model based on the provider.

        An explicit ``embedding_model`` in the config always wins; otherwise
        a provider-specific default is used.

        Returns:
            str: The embedding model name.
        """
        if "embedding_model" in self.config:
            return self.config["embedding_model"]
        if self.embedding_provider == "ollama":
            return "nomic-embed-text"
        # OpenAI default doubles as the safe fallback for unknown providers.
        return "text-embedding-3-small"

    def get_embedding(self, text: str) -> Optional[List[float]]:
        """
        Get embedding for a text using the configured provider.

        Args:
            text: The text to embed.

        Returns:
            List of floats representing the embedding, or None if memory is
            disabled or the embedding request fails.
        """
        if not self.enabled or not self.client:
            return None
        try:
            response = self.client.embeddings.create(
                model=self.embedding_model, input=text
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error("Failed to get embedding: %s", e)
            return None

    def add_situations(self, situations_and_advice: List[Tuple[str, str]]) -> bool:
        """
        Add financial situations and their corresponding advice.

        Args:
            situations_and_advice: List of tuples (situation, recommendation).

        Returns:
            bool: True if at least one situation was stored, False otherwise.
        """
        if not self.enabled:
            logger.debug("Memory disabled for %s, skipping add_situations", self.name)
            return False
        try:
            situations: List[str] = []
            advice: List[str] = []
            ids: List[str] = []
            embeddings: List[List[float]] = []
            # Continue numbering from the current collection size so IDs stay
            # unique across repeated add_situations() calls.
            offset = self.situation_collection.count()
            for i, (situation, recommendation) in enumerate(situations_and_advice):
                embedding = self.get_embedding(situation)
                if embedding is None:
                    logger.warning(
                        "Failed to get embedding for situation %d, skipping", i
                    )
                    continue
                # BUG FIX: derive the ID from the number of *accepted* items,
                # not the enumerate index. With the enumerate index, a skipped
                # item left a gap in the ID sequence that a later call (whose
                # offset = count()) would reuse, producing a duplicate ID.
                ids.append(str(offset + len(situations)))
                situations.append(situation)
                advice.append(recommendation)
                embeddings.append(embedding)
            if not situations:
                logger.warning("No valid situations to add")
                return False
            self.situation_collection.add(
                documents=situations,
                metadatas=[{"recommendation": rec} for rec in advice],
                embeddings=embeddings,
                ids=ids,
            )
            logger.info("Added %d situations to %s", len(situations), self.name)
            return True
        except Exception as e:
            logger.error("Failed to add situations: %s", e)
            return False

    def get_memories(
        self, current_situation: str, n_matches: int = 1
    ) -> List[Dict[str, Any]]:
        """
        Find matching recommendations using embeddings.

        Args:
            current_situation: The current situation to match against.
            n_matches: Number of matches to return.

        Returns:
            List of dictionaries containing matched situations and
            recommendations. Returns an empty list if memory is disabled or
            the query fails.
        """
        if not self.enabled:
            logger.debug(
                "Memory disabled for %s, returning empty memories", self.name
            )
            return []
        try:
            query_embedding = self.get_embedding(current_situation)
            if query_embedding is None:
                logger.warning(
                    "Failed to get query embedding, returning empty memories"
                )
                return []
            results = self.situation_collection.query(
                query_embeddings=[query_embedding],
                n_results=n_matches,
                include=["metadatas", "documents", "distances"],
            )
            # Chroma returns parallel per-query lists; we issued one query,
            # so index [0] everywhere. Similarity = 1 - distance (assumes a
            # cosine-style distance in [0, 2]).
            docs = results["documents"][0]
            metas = results["metadatas"][0]
            dists = results["distances"][0]
            return [
                {
                    "matched_situation": doc,
                    "recommendation": meta["recommendation"],
                    "similarity_score": 1 - dist,
                }
                for doc, meta, dist in zip(docs, metas, dists)
            ]
        except Exception as e:
            logger.error("Failed to get memories: %s", e)
            return []

    def is_enabled(self) -> bool:
        """Check if memory is enabled and functioning."""
        return self.enabled
if __name__ == "__main__":
    # Demo 1: a fully-enabled memory backed by the OpenAI embedding provider.
    print("=== Testing with OpenAI provider ===")
    openai_config = {
        "embedding_provider": "openai",
        "embedding_model": "text-embedding-3-small",
        "embedding_backend_url": "https://api.openai.com/v1",
        "enable_memory": True,
    }
    memory = FinancialSituationMemory("test_memory", openai_config)

    if not memory.is_enabled():
        print("Memory is disabled")
    else:
        # Seed data: (situation, recommendation) pairs.
        sample_pairs = [
            (
                "High inflation rate with rising interest rates and declining consumer spending",
                "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
            ),
            (
                "Tech sector showing high volatility with increasing institutional selling pressure",
                "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
            ),
            (
                "Strong dollar affecting emerging markets with increasing forex volatility",
                "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
            ),
            (
                "Market showing signs of sector rotation with rising yields",
                "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
            ),
        ]

        if not memory.add_situations(sample_pairs):
            print("Failed to add situations")
        else:
            # Query with a fresh situation and show the closest matches.
            query_situation = """
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""
            for idx, match in enumerate(
                memory.get_memories(query_situation, n_matches=2), 1
            ):
                print(f"\nMatch {idx}:")
                print(f"Similarity Score: {match['similarity_score']:.2f}")
                print(f"Matched Situation: {match['matched_situation']}")
                print(f"Recommendation: {match['recommendation']}")

    # Demo 2: provider "none" with memory explicitly switched off —
    # queries should come back empty instead of raising.
    print("\n=== Testing with disabled memory ===")
    disabled_config = {"embedding_provider": "none", "enable_memory": False}
    disabled_memory = FinancialSituationMemory("test_disabled", disabled_config)
    print(f"Memory enabled: {disabled_memory.is_enabled()}")
    empty_result = disabled_memory.get_memories("test situation")
    print(f"Get memories result: {empty_result}")