# memory.py — FinancialSituationMemory (ChromaDB-backed situation/recommendation store)
import os
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import chromadb
|
|
from openai import OpenAI
|
|
|
|
from tradingagents.utils.logger import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class FinancialSituationMemory:
    """Vector-store-backed memory of financial situations and recommendations.

    Situations are embedded (OpenAI or Ollama embedding models) and persisted
    in a local ChromaDB collection so similar past situations can be retrieved.
    """

    def __init__(self, name, config):
        """Open (or create) the persistent memory collection *name*.

        Args:
            name: ChromaDB collection name.
            config: Project config object exposing ``.get(key, default)`` and
                ``.validate_key(key, provider)``.
        """
        # Determine embedding backend URL.
        # For Ollama, use the Ollama endpoint; otherwise default to OpenAI
        # for embeddings, regardless of the LLM provider.
        if config.get("backend_url") == "http://localhost:11434/v1":
            self.embedding_backend = "http://localhost:11434/v1"
            self.embedding = "nomic-embed-text"
        else:
            self.embedding_backend = "https://api.openai.com/v1"
            self.embedding = "text-embedding-3-small"

        # BUG FIX: the selected backend URL was computed but never used, so the
        # Ollama branch would have sent "nomic-embed-text" requests to
        # api.openai.com (an unknown model there). Point the client at the
        # chosen backend explicitly; for the OpenAI branch this matches the
        # SDK's default base URL, so behavior there is unchanged.
        self.client = OpenAI(
            api_key=config.validate_key("openai_api_key", "OpenAI"),
            base_url=self.embedding_backend,
        )

        # Use persistent storage in the project directory.
        persist_directory = os.path.join(config.get("project_dir", "."), "memory_db")
        os.makedirs(persist_directory, exist_ok=True)

        self.chroma_client = chromadb.PersistentClient(path=persist_directory)

        # Get or create the collection (get_collection raises if it is absent).
        try:
            self.situation_collection = self.chroma_client.get_collection(name=name)
        except Exception:
            self.situation_collection = self.chroma_client.create_collection(name=name)
|
|
|
|
def get_embedding(self, text):
|
|
"""Get OpenAI embedding for a text"""
|
|
|
|
response = self.client.embeddings.create(model=self.embedding, input=text)
|
|
return response.data[0].embedding
|
|
|
|
def _batch_add(
|
|
self,
|
|
documents: List[str],
|
|
metadatas: List[Dict[str, Any]],
|
|
embeddings: List[List[float]],
|
|
ids: List[str] = None,
|
|
):
|
|
"""Internal helper to batch add documents to ChromaDB."""
|
|
if not documents:
|
|
return
|
|
|
|
if ids is None:
|
|
offset = self.situation_collection.count()
|
|
ids = [str(offset + i) for i in range(len(documents))]
|
|
|
|
self.situation_collection.add(
|
|
documents=documents,
|
|
metadatas=metadatas,
|
|
embeddings=embeddings,
|
|
ids=ids,
|
|
)
|
|
|
|
def add_situations(self, situations_and_advice):
|
|
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
|
situations = []
|
|
metadatas = []
|
|
embeddings = []
|
|
|
|
for situation, recommendation in situations_and_advice:
|
|
situations.append(situation)
|
|
metadatas.append({"recommendation": recommendation})
|
|
embeddings.append(self.get_embedding(situation))
|
|
|
|
self._batch_add(situations, metadatas, embeddings)
|
|
|
|
def add_situations_with_metadata(
|
|
self, situations_and_outcomes: List[Tuple[str, str, Dict[str, Any]]]
|
|
):
|
|
"""
|
|
Add financial situations with enhanced metadata for learning system.
|
|
|
|
Args:
|
|
situations_and_outcomes: List of tuples (situation_text, recommendation, metadata)
|
|
where metadata contains:
|
|
- ticker: Stock symbol
|
|
- analysis_date: Date of analysis (YYYY-MM-DD)
|
|
- days_before_move: How many days before the major move (7 or 30)
|
|
- move_pct: Percentage move that occurred
|
|
- move_direction: "up" or "down"
|
|
- agent_recommendation: What the agent recommended
|
|
- was_correct: Boolean, whether recommendation matched outcome
|
|
- structured_signals: Dict of signal features (optional)
|
|
- unusual_volume: bool
|
|
- analyst_sentiment: str (bullish/bearish/neutral)
|
|
- news_sentiment: str (positive/negative/neutral)
|
|
- short_interest: str (high/medium/low)
|
|
- insider_activity: str (buying/selling/none)
|
|
- etc.
|
|
"""
|
|
situations = []
|
|
metadatas = []
|
|
embeddings = []
|
|
|
|
for situation, recommendation, metadata in situations_and_outcomes:
|
|
situations.append(situation)
|
|
embeddings.append(self.get_embedding(situation))
|
|
|
|
# Merge recommendation with metadata
|
|
full_metadata = {"recommendation": recommendation}
|
|
full_metadata.update(metadata)
|
|
|
|
# Ensure all metadata values are strings, numbers, or booleans for ChromaDB
|
|
full_metadata = self._sanitize_metadata(full_metadata)
|
|
metadatas.append(full_metadata)
|
|
|
|
self._batch_add(situations, metadatas, embeddings)
|
|
|
|
def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
Sanitize metadata for ChromaDB compatibility.
|
|
ChromaDB requires metadata values to be str, int, float, or bool.
|
|
Nested dicts are flattened with dot notation.
|
|
"""
|
|
sanitized = {}
|
|
|
|
for key, value in metadata.items():
|
|
if isinstance(value, dict):
|
|
# Flatten nested dicts
|
|
for nested_key, nested_value in value.items():
|
|
flat_key = f"{key}.{nested_key}"
|
|
if isinstance(nested_value, (str, int, float, bool, type(None))):
|
|
sanitized[flat_key] = nested_value if nested_value is not None else "none"
|
|
elif isinstance(value, (str, int, float, bool, type(None))):
|
|
sanitized[key] = value if value is not None else "none"
|
|
else:
|
|
# Convert other types to string
|
|
sanitized[key] = str(value)
|
|
|
|
return sanitized
|
|
|
|
def get_memories(self, current_situation, n_matches=1):
|
|
"""Find matching recommendations using OpenAI embeddings"""
|
|
query_embedding = self.get_embedding(current_situation)
|
|
|
|
results = self.situation_collection.query(
|
|
query_embeddings=[query_embedding],
|
|
n_results=n_matches,
|
|
include=["metadatas", "documents", "distances"],
|
|
)
|
|
|
|
matched_results = []
|
|
for i in range(len(results["documents"][0])):
|
|
matched_results.append(
|
|
{
|
|
"matched_situation": results["documents"][0][i],
|
|
"recommendation": results["metadatas"][0][i]["recommendation"],
|
|
"similarity_score": 1 - results["distances"][0][i],
|
|
}
|
|
)
|
|
|
|
return matched_results
|
|
|
|
def get_memories_hybrid(
|
|
self,
|
|
current_situation: str,
|
|
signal_filters: Optional[Dict[str, Any]] = None,
|
|
n_matches: int = 3,
|
|
min_similarity: float = 0.5,
|
|
) -> List[Dict[str, Any]]:
|
|
"""
|
|
Hybrid search: Filter by structured signals, then rank by embedding similarity.
|
|
|
|
Args:
|
|
current_situation: Text description of current market situation
|
|
signal_filters: Dict of structured signals to filter by (e.g., {"unusual_volume": True})
|
|
Supports exact matches and can use dot notation for nested fields
|
|
e.g., {"structured_signals.unusual_volume": True}
|
|
n_matches: Number of results to return
|
|
min_similarity: Minimum similarity score (0-1) to include in results
|
|
|
|
Returns:
|
|
List of dicts with keys:
|
|
- matched_situation: Historical situation text
|
|
- recommendation: What was recommended
|
|
- similarity_score: Embedding similarity (0-1)
|
|
- metadata: Full metadata including outcome, signals, etc.
|
|
"""
|
|
query_embedding = self.get_embedding(current_situation)
|
|
|
|
# Build where clause for filtering
|
|
where_clause = None
|
|
if signal_filters:
|
|
where_clause = {}
|
|
for key, value in signal_filters.items():
|
|
where_clause[key] = value
|
|
|
|
# Query ChromaDB with optional filtering
|
|
query_params = {
|
|
"query_embeddings": [query_embedding],
|
|
"n_results": min(n_matches * 3, 100), # Get more results for filtering
|
|
"include": ["metadatas", "documents", "distances"],
|
|
}
|
|
|
|
if where_clause:
|
|
query_params["where"] = where_clause
|
|
|
|
results = self.situation_collection.query(**query_params)
|
|
|
|
# Process and filter results
|
|
matched_results = []
|
|
for i in range(len(results["documents"][0])):
|
|
similarity_score = 1 - results["distances"][0][i]
|
|
|
|
# Apply similarity threshold
|
|
if similarity_score < min_similarity:
|
|
continue
|
|
|
|
metadata = results["metadatas"][0][i]
|
|
|
|
matched_results.append(
|
|
{
|
|
"matched_situation": results["documents"][0][i],
|
|
"recommendation": metadata.get("recommendation", ""),
|
|
"similarity_score": similarity_score,
|
|
"metadata": metadata,
|
|
# Extract key fields for convenience
|
|
"ticker": metadata.get("ticker", ""),
|
|
"move_pct": metadata.get("move_pct", 0),
|
|
"move_direction": metadata.get("move_direction", ""),
|
|
"was_correct": metadata.get("was_correct", False),
|
|
"days_before_move": metadata.get("days_before_move", 0),
|
|
}
|
|
)
|
|
|
|
# Return top n_matches
|
|
return matched_results[:n_matches]
|
|
|
|
def get_statistics(self) -> Dict[str, Any]:
|
|
"""
|
|
Get statistics about the memory bank.
|
|
|
|
Returns:
|
|
Dict with keys:
|
|
- total_memories: Total number of stored memories
|
|
- accuracy_rate: % of memories where was_correct=True
|
|
- avg_move_pct: Average percentage move in stored outcomes
|
|
- signal_distribution: Count of different signal patterns
|
|
"""
|
|
total_count = self.situation_collection.count()
|
|
|
|
if total_count == 0:
|
|
return {
|
|
"total_memories": 0,
|
|
"accuracy_rate": 0.0,
|
|
"avg_move_pct": 0.0,
|
|
"signal_distribution": {},
|
|
}
|
|
|
|
# Get all memories
|
|
all_results = self.situation_collection.get(include=["metadatas"])
|
|
|
|
metadatas = all_results["metadatas"]
|
|
|
|
# Calculate statistics
|
|
correct_count = sum(1 for m in metadatas if m.get("was_correct") == True)
|
|
accuracy_rate = (correct_count / total_count * 100) if total_count > 0 else 0
|
|
|
|
move_pcts = [m.get("move_pct", 0) for m in metadatas if "move_pct" in m]
|
|
avg_move_pct = sum(move_pcts) / len(move_pcts) if move_pcts else 0
|
|
|
|
# Count signal patterns
|
|
signal_distribution = {}
|
|
for metadata in metadatas:
|
|
for key, value in metadata.items():
|
|
if key.startswith("structured_signals."):
|
|
signal_name = key.replace("structured_signals.", "")
|
|
if signal_name not in signal_distribution:
|
|
signal_distribution[signal_name] = {}
|
|
if value not in signal_distribution[signal_name]:
|
|
signal_distribution[signal_name][value] = 0
|
|
signal_distribution[signal_name][value] += 1
|
|
|
|
return {
|
|
"total_memories": total_count,
|
|
"accuracy_rate": accuracy_rate,
|
|
"avg_move_pct": avg_move_pct,
|
|
"signal_distribution": signal_distribution,
|
|
}
|
|
|
|
|
|
if __name__ == "__main__":
    # Example usage.
    # BUG FIX: FinancialSituationMemory() was previously called with no
    # arguments, but __init__ requires a collection name and a config object
    # exposing .get(...) and .validate_key(...). A minimal stand-in config is
    # built here; real usage should pass the project's config object.
    class _ExampleConfig(dict):
        def validate_key(self, key, provider):
            # Read the API key from the environment as the project config would.
            value = os.environ.get(key.upper())
            if not value:
                raise RuntimeError(f"{provider} key '{key}' not set in environment")
            return value

    matcher = FinancialSituationMemory("example_memory", _ExampleConfig())

    # Example data
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations
    matcher.add_situations(example_data)

    # Example query
    current_situation = """
    Market showing increased volatility in tech sector, with institutional investors
    reducing positions and rising interest rates affecting growth stock valuations
    """

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)

        for i, rec in enumerate(recommendations, 1):
            logger.info(f"Match {i}:")
            logger.info(f"Similarity Score: {rec['similarity_score']:.2f}")
            logger.info(f"Matched Situation: {rec['matched_situation']}")
            logger.info(f"Recommendation: {rec['recommendation']}")

    except Exception as e:
        logger.error(f"Error during recommendation: {str(e)}")
|