# 474 lines, 17 KiB, Python — viewer banner captured during extraction; not part of the source.
import chromadb
|
|
from chromadb.config import Settings
|
|
from openai import OpenAI
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
import time
|
|
|
|
from tradingagents.utils.logging_config import (
|
|
get_logger,
|
|
get_api_logger,
|
|
get_performance_logger,
|
|
)
|
|
|
|
# Module-level loggers shared by every FinancialSituationMemory instance:
# structured application log, API call log, and timing/performance log.
logger = get_logger("tradingagents.memory", component="MEMORY")

api_logger = get_api_logger()

perf_logger = get_performance_logger()
|
|
|
|
|
|
class FinancialSituationMemory:
    """
    Memory system for financial trading agents with support for multiple embedding providers.

    Supports:
    - OpenAI embeddings
    - Ollama local embeddings
    - Graceful fallback when embeddings are unavailable
    """

    def __init__(self, name: str, config: Dict[str, Any]):
        """
        Initialize the financial situation memory.

        Args:
            name: Name of the memory collection
            config: Configuration dictionary containing embedding settings
        """
        self.name = name
        self.config = config
        self.enabled = config.get("enable_memory", True)

        # Initialize embedding client and model based on provider.
        # `or "openai"` guards against an explicit None value in the config,
        # which would otherwise crash on .lower().
        self.embedding_provider = (config.get("embedding_provider") or "openai").lower()
        self.embedding_model = self._get_embedding_model()
        self.embedding_backend_url = config.get(
            "embedding_backend_url", "https://api.openai.com/v1"
        )

        # Initialize OpenAI client for embeddings (if enabled and supported).
        # Ollama exposes an OpenAI-compatible endpoint, so the same client
        # serves both providers via `base_url`.
        self.client = None
        if self.enabled and self.embedding_provider in ("openai", "ollama"):
            try:
                start_time = time.time()
                self.client = OpenAI(base_url=self.embedding_backend_url)
                init_duration = (time.time() - start_time) * 1000

                logger.info(
                    f"Initialized embedding client for '{name}'",
                    extra={
                        "context": {
                            "provider": self.embedding_provider,
                            "backend_url": self.embedding_backend_url,
                            "model": self.embedding_model,
                            "init_time_ms": init_duration,
                        }
                    },
                )
                perf_logger.log_timing(
                    "embedding_client_init",
                    init_duration,
                    {"provider": self.embedding_provider},
                )
            except Exception as e:
                logger.warning(
                    f"Failed to initialize embedding client for '{name}': {e}. Memory will be disabled.",
                    extra={
                        "context": {
                            "provider": self.embedding_provider,
                            "error": str(e),
                        }
                    },
                )
                self.enabled = False
        elif not self.enabled:
            logger.info(f"Memory disabled for '{name}' (enable_memory=False)")
        elif self.embedding_provider == "none":
            logger.info(
                f"Embedding provider set to 'none' for '{name}'. Memory will be disabled."
            )
            self.enabled = False
        else:
            logger.warning(
                f"Unsupported embedding provider '{self.embedding_provider}' for '{name}'. Memory will be disabled."
            )
            self.enabled = False

        # Initialize ChromaDB collection
        self.chroma_client = None
        self.situation_collection = None
        if self.enabled:
            try:
                start_time = time.time()
                self.chroma_client = chromadb.Client(Settings(allow_reset=True))
                # get_or_create_collection (not create_collection): creating a
                # collection whose name already exists raises, which would
                # needlessly disable memory when the same name is re-initialized.
                self.situation_collection = self.chroma_client.get_or_create_collection(
                    name=name
                )
                init_duration = (time.time() - start_time) * 1000

                logger.info(
                    f"Initialized ChromaDB collection '{name}'",
                    extra={
                        "context": {"collection": name, "init_time_ms": init_duration}
                    },
                )
                perf_logger.log_timing(
                    "chromadb_collection_init", init_duration, {"collection": name}
                )
            except Exception as e:
                logger.error(
                    f"Failed to initialize ChromaDB collection '{name}': {e}. Memory will be disabled.",
                    extra={"context": {"collection": name, "error": str(e)}},
                )
                self.enabled = False

    def _get_embedding_model(self) -> str:
        """
        Get the appropriate embedding model based on the provider.

        An explicit "embedding_model" config entry always wins; otherwise a
        provider-specific default is used.

        Returns:
            str: The embedding model name
        """
        if "embedding_model" in self.config:
            return self.config["embedding_model"]

        # Fall back to provider-specific defaults
        if self.embedding_provider == "ollama":
            return "nomic-embed-text"
        elif self.embedding_provider == "openai":
            return "text-embedding-3-small"
        else:
            return "text-embedding-3-small"  # Safe default

    def get_embedding(self, text: str) -> Optional[List[float]]:
        """
        Get embedding for a text using the configured provider.

        Args:
            text: The text to embed

        Returns:
            List of floats representing the embedding, or None if embedding
            fails or memory is disabled. Failures are logged, never raised.
        """
        if not self.enabled or not self.client:
            logger.debug("Embedding skipped (memory disabled or no client)")
            return None

        try:
            start_time = time.time()
            response = self.client.embeddings.create(
                model=self.embedding_model, input=text
            )
            duration = (time.time() - start_time) * 1000

            embedding = response.data[0].embedding

            # Log API call (token count is a rough whitespace-split estimate)
            api_logger.log_call(
                provider=self.embedding_provider,
                model=self.embedding_model,
                endpoint="embeddings.create",
                tokens=len(text.split()),  # Rough estimate
                duration=duration,
                status="success",
            )

            logger.debug(
                f"Generated embedding for text ({len(text)} chars)",
                extra={
                    "context": {
                        "provider": self.embedding_provider,
                        "model": self.embedding_model,
                        "text_length": len(text),
                        "duration_ms": duration,
                    }
                },
            )

            return embedding
        except Exception as e:
            logger.error(
                f"Failed to get embedding: {e}",
                extra={
                    "context": {
                        "provider": self.embedding_provider,
                        "model": self.embedding_model,
                        "text_length": len(text),
                        "error": str(e),
                    }
                },
            )

            # Log failed API call
            api_logger.log_call(
                provider=self.embedding_provider,
                model=self.embedding_model,
                endpoint="embeddings.create",
                status="error",
                error=str(e),
            )

            return None

    def add_situations(self, situations_and_advice: List[Tuple[str, str]]) -> bool:
        """
        Add financial situations and their corresponding advice.

        Situations whose embedding fails are skipped; the rest are stored.

        Args:
            situations_and_advice: List of tuples (situation, recommendation)

        Returns:
            bool: True if at least one situation was stored, False otherwise
        """
        if not self.enabled:
            logger.debug(
                f"Memory disabled for '{self.name}', skipping add_situations",
                extra={
                    "context": {
                        "collection": self.name,
                        "count": len(situations_and_advice),
                    }
                },
            )
            return False

        try:
            start_time = time.time()
            situations = []
            advice = []
            ids = []
            embeddings = []

            offset = self.situation_collection.count()

            for i, (situation, recommendation) in enumerate(situations_and_advice):
                embedding = self.get_embedding(situation)
                if embedding is None:
                    logger.warning(
                        f"Failed to get embedding for situation {i} in '{self.name}', skipping",
                        extra={
                            "context": {
                                "collection": self.name,
                                "situation_index": i,
                                "situation_preview": situation[:100],
                            }
                        },
                    )
                    continue

                situations.append(situation)
                advice.append(recommendation)
                # ID is derived from the number of items actually kept so far
                # (offset + len(ids)), NOT the input index i: with `offset + i`
                # a skipped embedding would leave a gap, and the next batch —
                # whose offset is the count of *stored* items — would reuse an
                # already-taken ID and collide.
                ids.append(str(offset + len(ids)))
                embeddings.append(embedding)

            if not situations:
                logger.warning(
                    f"No valid situations to add to '{self.name}'",
                    extra={
                        "context": {
                            "collection": self.name,
                            "attempted": len(situations_and_advice),
                        }
                    },
                )
                return False

            self.situation_collection.add(
                documents=situations,
                metadatas=[{"recommendation": rec} for rec in advice],
                embeddings=embeddings,
                ids=ids,
            )

            duration = (time.time() - start_time) * 1000

            logger.info(
                f"Added {len(situations)} situations to '{self.name}'",
                extra={
                    "context": {
                        "collection": self.name,
                        "count": len(situations),
                        "total_in_collection": self.situation_collection.count(),
                        "duration_ms": duration,
                    }
                },
            )

            perf_logger.log_timing(
                "add_situations",
                duration,
                {"collection": self.name, "count": len(situations)},
            )

            return True

        except Exception as e:
            logger.error(
                f"Failed to add situations to '{self.name}': {e}",
                extra={
                    "context": {
                        "collection": self.name,
                        "attempted_count": len(situations_and_advice),
                        "error": str(e),
                    }
                },
            )
            return False

    def get_memories(
        self, current_situation: str, n_matches: int = 1
    ) -> List[Dict[str, Any]]:
        """
        Find matching recommendations using embeddings.

        Args:
            current_situation: The current situation to match against
            n_matches: Number of matches to return

        Returns:
            List of dictionaries containing matched situations and recommendations.
            Returns empty list if memory is disabled, the collection is empty,
            or the query fails.
        """
        if not self.enabled:
            logger.debug(
                f"Memory disabled for '{self.name}', returning empty memories",
                extra={"context": {"collection": self.name}},
            )
            return []

        try:
            start_time = time.time()

            query_embedding = self.get_embedding(current_situation)
            if query_embedding is None:
                logger.warning(
                    f"Failed to get query embedding for '{self.name}', returning empty memories",
                    extra={
                        "context": {
                            "collection": self.name,
                            "situation_preview": current_situation[:100],
                        }
                    },
                )
                return []

            # ChromaDB errors/warns when n_results exceeds the number of
            # stored documents, so guard the empty collection and clamp.
            available = self.situation_collection.count()
            if available == 0:
                logger.debug(
                    f"No situations stored in '{self.name}', returning empty memories",
                    extra={"context": {"collection": self.name}},
                )
                return []

            results = self.situation_collection.query(
                query_embeddings=[query_embedding],
                n_results=min(n_matches, available),
                include=["metadatas", "documents", "distances"],
            )

            matched_results = []
            for i in range(len(results["documents"][0])):
                # Chroma returns distances; convert to a similarity score.
                similarity = 1 - results["distances"][0][i]
                matched_results.append(
                    {
                        "matched_situation": results["documents"][0][i],
                        "recommendation": results["metadatas"][0][i]["recommendation"],
                        "similarity_score": similarity,
                    }
                )

            duration = (time.time() - start_time) * 1000

            logger.info(
                f"Retrieved {len(matched_results)} memories from '{self.name}'",
                extra={
                    "context": {
                        "collection": self.name,
                        "requested": n_matches,
                        "returned": len(matched_results),
                        "top_similarity": matched_results[0]["similarity_score"]
                        if matched_results
                        else 0,
                        "duration_ms": duration,
                    }
                },
            )

            perf_logger.log_timing(
                "get_memories",
                duration,
                {"collection": self.name, "n_matches": n_matches},
            )

            return matched_results

        except Exception as e:
            logger.error(
                f"Failed to get memories from '{self.name}': {e}",
                extra={
                    "context": {
                        "collection": self.name,
                        "n_matches": n_matches,
                        "error": str(e),
                    }
                },
            )
            return []

    def is_enabled(self) -> bool:
        """Check if memory is enabled and functioning."""
        return self.enabled
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo 1: a live OpenAI-backed memory — store a few situations and query.
    print("=== Testing with OpenAI provider ===")
    openai_settings = {
        "embedding_provider": "openai",
        "embedding_model": "text-embedding-3-small",
        "embedding_backend_url": "https://api.openai.com/v1",
        "enable_memory": True,
    }

    memory = FinancialSituationMemory("test_memory", openai_settings)

    if not memory.is_enabled():
        print("Memory is disabled")
    else:
        # Seed data: (situation, recommendation) pairs.
        demo_pairs = [
            (
                "High inflation rate with rising interest rates and declining consumer spending",
                "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
            ),
            (
                "Tech sector showing high volatility with increasing institutional selling pressure",
                "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
            ),
            (
                "Strong dollar affecting emerging markets with increasing forex volatility",
                "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
            ),
            (
                "Market showing signs of sector rotation with rising yields",
                "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
            ),
        ]

        if not memory.add_situations(demo_pairs):
            print("Failed to add situations")
        else:
            # Query with a fresh situation and show the closest matches.
            query_text = """
            Market showing increased volatility in tech sector, with institutional investors
            reducing positions and rising interest rates affecting growth stock valuations
            """

            hits = memory.get_memories(query_text, n_matches=2)

            for rank, hit in enumerate(hits, 1):
                print(f"\nMatch {rank}:")
                print(f"Similarity Score: {hit['similarity_score']:.2f}")
                print(f"Matched Situation: {hit['matched_situation']}")
                print(f"Recommendation: {hit['recommendation']}")

    # Demo 2: the disabled path — queries must return an empty list.
    print("\n=== Testing with disabled memory ===")
    muted_settings = {"embedding_provider": "none", "enable_memory": False}

    muted = FinancialSituationMemory("test_disabled", muted_settings)
    print(f"Memory enabled: {muted.is_enabled()}")
    empty_result = muted.get_memories("test situation")
    print(f"Get memories result: {empty_result}")
|