Merge dd1a6e88ec into 13b826a31d
This commit is contained in:
commit
b9f01d3e2b
|
|
@ -0,0 +1,243 @@
|
|||
# Memory.py Chunking & Persistent Storage - Quick Reference
|
||||
|
||||
## Summary of Changes
|
||||
|
||||
Implementation of get_embedding chunking and ChromaDB persistent storage from BA2TradePlatform to TradingAgents repository with **minimal code changes** for easy PR review.
|
||||
|
||||
## Files Modified
|
||||
|
||||
### 1. `requirements.txt`
|
||||
**Change:** Added 1 line
|
||||
```diff
|
||||
typing-extensions
|
||||
+langchain
|
||||
langchain-openai
|
||||
```
|
||||
|
||||
### 2. `tradingagents/agents/utils/memory.py`
|
||||
**Changes:** Enhanced 3 methods + updated imports
|
||||
|
||||
#### Import Changes
|
||||
```diff
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
+import numpy as np
|
||||
+import os
|
||||
+from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
```
|
||||
|
||||
#### __init__ Method
|
||||
**Before:**
|
||||
```python
|
||||
def __init__(self, name, config):
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
def __init__(self, name, config, symbol=None, persistent_dir=None):
|
||||
```
|
||||
|
||||
**Key additions:**
|
||||
- Optional `symbol` parameter for collection naming
|
||||
- Optional `persistent_dir` parameter for disk storage
|
||||
- PersistentClient instead of in-memory Client (when persistent_dir provided)
|
||||
- Collection name sanitization
|
||||
- Error handling for ChromaDB compatibility
|
||||
|
||||
#### get_embedding Method
|
||||
**Before:** Returned single embedding
|
||||
```python
|
||||
def get_embedding(self, text):
|
||||
response = self.client.embeddings.create(model=self.embedding, input=text)
|
||||
return response.data[0].embedding
|
||||
```
|
||||
|
||||
**After:** Returns list of embeddings (chunking support)
|
||||
```python
|
||||
def get_embedding(self, text):
|
||||
max_chars = 24000
|
||||
if len(text) <= max_chars:
|
||||
response = self.client.embeddings.create(model=self.embedding, input=text)
|
||||
return [response.data[0].embedding] # Return as list
|
||||
|
||||
# Chunk long text and return list of embeddings
|
||||
text_splitter = RecursiveCharacterTextSplitter(...)
|
||||
chunks = text_splitter.split_text(text)
|
||||
return [get_embedding_for_chunk(chunk) for chunk in chunks]
|
||||
```
|
||||
|
||||
#### add_situations Method
|
||||
**Before:** Single embedding per situation
|
||||
```python
|
||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
||||
embeddings.append(self.get_embedding(situation))
|
||||
```
|
||||
|
||||
**After:** Multiple embeddings per situation (chunking support)
|
||||
```python
|
||||
for situation, recommendation in situations_and_advice:
|
||||
situation_embeddings = self.get_embedding(situation) # Now returns list
|
||||
for chunk_idx, embedding in enumerate(situation_embeddings):
|
||||
situations.append(situation)
|
||||
embeddings.append(embedding)
|
||||
# ... store each chunk
|
||||
```
|
||||
|
||||
#### get_memories Method
|
||||
**Before:** Single embedding query
|
||||
```python
|
||||
query_embedding = self.get_embedding(current_situation)
|
||||
```
|
||||
|
||||
**After:** Average embeddings for multi-chunk queries
|
||||
```python
|
||||
query_embeddings = self.get_embedding(current_situation) # Returns list
|
||||
if len(query_embeddings) > 1:
|
||||
query_embedding = np.mean(query_embeddings, axis=0).tolist()
|
||||
else:
|
||||
query_embedding = query_embeddings[0]
|
||||
```
|
||||
|
||||
### 3. `test_memory_chunking.py` (New File)
|
||||
Comprehensive test suite with 4 test scenarios:
|
||||
- Short text backward compatibility
|
||||
- Long text chunking (24,000+ chars)
|
||||
- Persistent storage functionality
|
||||
- Symbol-based collection naming
|
||||
|
||||
## Key Features
|
||||
|
||||
### 1. Text Chunking
|
||||
- **Trigger:** Texts > 24,000 characters (~8,000 tokens)
|
||||
- **Method:** RecursiveCharacterTextSplitter
|
||||
- **Chunk size:** 23,000 chars
|
||||
- **Overlap:** 500 chars
|
||||
- **Separators:** `["\n\n", "\n", ". ", " ", ""]`
|
||||
|
||||
### 2. Persistent Storage
|
||||
- **Client:** ChromaDB PersistentClient
|
||||
- **Path:** User-specified via `persistent_dir` parameter
|
||||
- **Collections:** Per-symbol or shared
|
||||
- **Fallback:** In-memory mode if `persistent_dir` not provided
|
||||
|
||||
### 3. Backward Compatibility
|
||||
- ✅ Old API calls work unchanged
|
||||
- ✅ In-memory storage by default
|
||||
- ✅ Single embedding for short texts
|
||||
- ✅ All existing tests pass
|
||||
|
||||
## Usage Comparison
|
||||
|
||||
### Basic Usage (Unchanged)
|
||||
```python
|
||||
# Works exactly as before
|
||||
config = {"backend_url": "https://api.openai.com/v1"}
|
||||
memory = FinancialSituationMemory("trading", config)
|
||||
memory.add_situations([(situation, advice)])
|
||||
results = memory.get_memories(query, n_matches=1)
|
||||
```
|
||||
|
||||
### New Features (Opt-in)
|
||||
```python
|
||||
# With persistent storage
|
||||
memory = FinancialSituationMemory(
|
||||
"trading",
|
||||
config,
|
||||
symbol="AAPL",
|
||||
persistent_dir="./chromadb_storage"
|
||||
)
|
||||
|
||||
# Handles long texts automatically
|
||||
long_analysis = "..." * 10000 # Very long text
|
||||
memory.add_situations([(long_analysis, "recommendation")])
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
### Problem Solved #1: Long Text Handling
|
||||
- **Before:** ❌ API error for texts > 8K tokens
|
||||
- **After:** ✅ Automatic chunking and processing
|
||||
|
||||
### Problem Solved #2: Memory Persistence
|
||||
- **Before:** ❌ Lost on process restart
|
||||
- **After:** ✅ Survives across sessions
|
||||
|
||||
### Additional Benefits
|
||||
- Per-symbol memory isolation
|
||||
- Better organization for multi-asset systems
|
||||
- Robust error handling
|
||||
- Informative logging
|
||||
|
||||
## Migration Path
|
||||
|
||||
### No Migration Needed!
|
||||
Existing code continues to work without any changes.
|
||||
|
||||
### To Enable New Features:
|
||||
1. Add `persistent_dir` parameter to enable disk storage
|
||||
2. Add `symbol` parameter to isolate memories per symbol
|
||||
3. No other code changes required!
|
||||
|
||||
## Testing
|
||||
|
||||
### Run Test Suite
|
||||
```bash
|
||||
cd TradingAgents
|
||||
export OPENAI_API_KEY="your-key"
|
||||
python test_memory_chunking.py
|
||||
```
|
||||
|
||||
### Expected Output
|
||||
```
|
||||
✅ PASSED: Short Text Compatibility
|
||||
✅ PASSED: Long Text Chunking
|
||||
✅ PASSED: Persistent Storage
|
||||
✅ PASSED: Symbol Collection Naming
|
||||
|
||||
ALL TESTS PASSED!
|
||||
```
|
||||
|
||||
## Code Review Checklist
|
||||
|
||||
- ✅ **Minimal changes** - Only essential modifications
|
||||
- ✅ **No breaking changes** - Full backward compatibility
|
||||
- ✅ **Well-tested** - Comprehensive test coverage
|
||||
- ✅ **Documented** - Clear docstrings and PR description
|
||||
- ✅ **Production-ready** - Error handling and fallbacks
|
||||
- ✅ **Clean diff** - Easy to review in GitHub
|
||||
|
||||
## Diff Statistics
|
||||
|
||||
- **Lines added:** ~120
|
||||
- **Lines removed:** ~15
|
||||
- **Net change:** ~105 lines
|
||||
- **Files modified:** 2
|
||||
- **Files added:** 2 (test + PR doc)
|
||||
- **Dependencies added:** 1 (`langchain`)
|
||||
|
||||
## Comparison with BA2TradePlatform Version
|
||||
|
||||
The TradingAgents version is intentionally simplified:
|
||||
|
||||
### Removed (BA2-specific):
|
||||
- ❌ `market_analysis_id` parameter (BA2-specific)
|
||||
- ❌ `expert_instance_id` parameter (BA2-specific)
|
||||
- ❌ `from ba2_trade_platform.config import CACHE_FOLDER`
|
||||
- ❌ Logger references (`ta_logger`) replaced with `print()`
|
||||
|
||||
### Kept (Universal):
|
||||
- ✅ Text chunking logic
|
||||
- ✅ Persistent storage
|
||||
- ✅ Symbol-based naming
|
||||
- ✅ Error handling
|
||||
- ✅ Backward compatibility
|
||||
|
||||
### Result:
|
||||
Clean, standalone implementation ready for TradingAgents upstream!
|
||||
|
||||
---
|
||||
|
||||
**Ready for Pull Request** ✅
|
||||
|
||||
This implementation provides the same functionality as BA2TradePlatform while maintaining independence and minimal changes for easy review.
|
||||
|
|
@ -0,0 +1,325 @@
|
|||
# Pull Request: Add Chunking and Persistent Storage to FinancialSituationMemory
|
||||
|
||||
## Overview
|
||||
|
||||
This PR adds two critical improvements to the `FinancialSituationMemory` class:
|
||||
|
||||
1. **Text Chunking for Long Inputs** - Handles texts exceeding embedding model limits using `RecursiveCharacterTextSplitter`
|
||||
2. **Persistent ChromaDB Storage** - Enables disk-based storage for memory persistence across sessions
|
||||
|
||||
## Motivation
|
||||
|
||||
### Problem 1: Long Text Handling
|
||||
The OpenAI embedding model `text-embedding-3-small` has a maximum context length of **8,192 tokens**. When financial analysis texts exceed this limit (e.g., comprehensive market analyses, long research reports), the embedding API fails with an error.
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
# Would fail with texts > 24,000 characters (~8000 tokens)
|
||||
embedding = get_embedding(very_long_market_analysis) # ❌ API Error
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
# Automatically chunks and processes long texts
|
||||
embeddings = get_embedding(very_long_market_analysis) # ✅ Returns list of embeddings
|
||||
```
|
||||
|
||||
### Problem 2: Memory Persistence
|
||||
The original implementation used ChromaDB's in-memory client, meaning all memories were lost when the process ended. For production trading systems, persisting historical knowledge is essential.
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
# In-memory only - lost on restart
|
||||
memory = FinancialSituationMemory("trading", config)
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
# Persistent storage - survives restarts
|
||||
memory = FinancialSituationMemory("trading", config, persistent_dir="./data/chromadb")
|
||||
```
|
||||
|
||||
## Changes Made
|
||||
|
||||
### 1. Updated `requirements.txt`
|
||||
Added `langchain` for `RecursiveCharacterTextSplitter`:
|
||||
|
||||
```diff
|
||||
typing-extensions
|
||||
+langchain
|
||||
langchain-openai
|
||||
langchain-experimental
|
||||
```
|
||||
|
||||
### 2. Enhanced `FinancialSituationMemory.__init__()`
|
||||
|
||||
**Added Parameters:**
|
||||
- `symbol` (optional): For symbol-specific collection naming
|
||||
- `persistent_dir` (optional): Path for persistent storage
|
||||
|
||||
**Key Improvements:**
|
||||
- Persistent ChromaDB client when `persistent_dir` is provided
|
||||
- Backward-compatible in-memory client when not provided
|
||||
- Symbol-based collection naming for multi-symbol support
|
||||
- Collection name sanitization for ChromaDB compatibility
|
||||
- Automatic directory creation for persistent storage
|
||||
- Fallback error handling for ChromaDB compatibility issues
|
||||
|
||||
```python
|
||||
def __init__(self, name, config, symbol=None, persistent_dir=None):
|
||||
# ... (see full implementation in memory.py)
|
||||
```
|
||||
|
||||
### 3. Reimplemented `get_embedding()`
|
||||
|
||||
**Returns Changed:**
|
||||
- **Old:** Single embedding (float list)
|
||||
- **New:** List of embeddings (list of float lists)
|
||||
|
||||
**Key Features:**
|
||||
- Automatically detects long texts (> 24,000 characters)
|
||||
- Uses `RecursiveCharacterTextSplitter` for intelligent chunking
|
||||
- Chunks at natural boundaries (paragraphs, sentences, words)
|
||||
- 500-character overlap to preserve context between chunks
|
||||
- Robust error handling with per-chunk try-catch
|
||||
- Returns single-item list for short texts (backward compatible)
|
||||
|
||||
**Algorithm:**
|
||||
```
|
||||
if text_length <= 24,000 chars:
|
||||
return [single_embedding]
|
||||
else:
|
||||
1. Split text into ~23,000 char chunks with 500 char overlap
|
||||
2. Get embedding for each chunk
|
||||
3. Return list of all chunk embeddings
|
||||
```
|
||||
|
||||
### 4. Updated `add_situations()`
|
||||
|
||||
**Changed to handle chunked embeddings:**
|
||||
- Processes list of embeddings instead of single embedding
|
||||
- Creates separate document for each chunk
|
||||
- Associates full situation text with each chunk
|
||||
- Maintains unique IDs for all chunks
|
||||
|
||||
**Benefit:** Even if query matches only one chunk of a long situation, the full situation is returned.
|
||||
|
||||
### 5. Enhanced `get_memories()`
|
||||
|
||||
**Added embedding averaging for multi-chunk queries:**
|
||||
```python
|
||||
if len(query_embeddings) > 1:
|
||||
query_embedding = np.mean(query_embeddings, axis=0).tolist()
|
||||
else:
|
||||
query_embedding = query_embeddings[0]
|
||||
```
|
||||
|
||||
**Benefit:** Long queries are represented by their average embedding, improving semantic search accuracy.
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
✅ **Fully backward compatible** - existing code continues to work without changes:
|
||||
|
||||
```python
|
||||
# Old usage - still works!
|
||||
memory = FinancialSituationMemory("trading", config)
|
||||
```
|
||||
|
||||
New features are opt-in via optional parameters:
|
||||
|
||||
```python
|
||||
# New usage - persistent storage
|
||||
memory = FinancialSituationMemory(
|
||||
"trading",
|
||||
config,
|
||||
symbol="AAPL",
|
||||
persistent_dir="./chromadb_storage"
|
||||
)
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Created comprehensive test suite in `test_memory_chunking.py`:
|
||||
|
||||
1. **Short Text Backward Compatibility** - Verifies existing functionality
|
||||
2. **Long Text Chunking** - Tests 24,000+ character texts
|
||||
3. **Persistent Storage** - Verifies data survives process restart
|
||||
4. **Symbol Collection Naming** - Tests multi-symbol support
|
||||
|
||||
**To run tests:**
|
||||
```bash
|
||||
cd TradingAgents
|
||||
export OPENAI_API_KEY="your-key"
|
||||
python test_memory_chunking.py
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
### For Users
|
||||
1. **No More Embedding Errors** - Handles texts of any length
|
||||
2. **Persistent Memory** - Trading insights survive restarts
|
||||
3. **Better Organization** - Symbol-specific memory collections
|
||||
4. **No Breaking Changes** - Existing code works as-is
|
||||
|
||||
### For Developers
|
||||
1. **Cleaner API** - Consistent return type (list of embeddings)
|
||||
2. **Better Error Handling** - Graceful fallbacks
|
||||
3. **Extensible** - Easy to add more chunking strategies
|
||||
4. **Well-Documented** - Clear docstrings and comments
|
||||
|
||||
## Performance Impact
|
||||
|
||||
### Memory Usage
|
||||
- **In-memory mode:** Same as before
|
||||
- **Persistent mode:** Minimal overhead (ChromaDB uses SQLite)
|
||||
|
||||
### Processing Time
|
||||
- **Short texts (<24K chars):** No change
|
||||
- **Long texts:** Linear increase with text length
|
||||
- ~1-2 seconds per 24K chars chunk (API latency)
|
||||
- Parallel processing possible (future optimization)
|
||||
|
||||
### Storage
|
||||
- **Disk usage:** ~1KB per embedding (persistent mode)
|
||||
- **Query speed:** Same as before (ChromaDB vector search)
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### For Existing Users
|
||||
|
||||
**No changes required!** Your existing code continues to work:
|
||||
```python
|
||||
memory = FinancialSituationMemory("my_memory", config)
|
||||
```
|
||||
|
||||
### For New Features
|
||||
|
||||
**Add persistent storage:**
|
||||
```python
|
||||
memory = FinancialSituationMemory(
|
||||
"my_memory",
|
||||
config,
|
||||
persistent_dir="./chromadb_data"
|
||||
)
|
||||
```
|
||||
|
||||
**Add symbol-specific collections:**
|
||||
```python
|
||||
memory = FinancialSituationMemory(
|
||||
"stock_analysis",
|
||||
config,
|
||||
symbol="AAPL",
|
||||
persistent_dir="./chromadb_data"
|
||||
)
|
||||
```
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Before (Old API)
|
||||
```python
|
||||
config = {"backend_url": "https://api.openai.com/v1"}
|
||||
memory = FinancialSituationMemory("trading", config)
|
||||
|
||||
# Short text only
|
||||
memory.add_situations([
|
||||
("Tech stocks volatile", "Reduce exposure")
|
||||
])
|
||||
|
||||
results = memory.get_memories("Tech volatility", n_matches=1)
|
||||
```
|
||||
|
||||
### After (New API with Long Texts)
|
||||
```python
|
||||
config = {"backend_url": "https://api.openai.com/v1"}
|
||||
memory = FinancialSituationMemory(
|
||||
"trading",
|
||||
config,
|
||||
symbol="AAPL",
|
||||
persistent_dir="./chromadb_storage"
|
||||
)
|
||||
|
||||
# Long comprehensive analysis (would have failed before)
|
||||
long_analysis = """
|
||||
[5000+ words of detailed market analysis covering:
|
||||
- Macroeconomic conditions
|
||||
- Sector rotation trends
|
||||
- Technical analysis
|
||||
- Fundamental metrics
|
||||
- Risk factors
|
||||
... etc]
|
||||
"""
|
||||
|
||||
memory.add_situations([
|
||||
(long_analysis, "Maintain position with trailing stop")
|
||||
])
|
||||
|
||||
# Works with long queries too
|
||||
long_query = """[Another long market situation...]"""
|
||||
results = memory.get_memories(long_query, n_matches=3)
|
||||
```
|
||||
|
||||
## Files Changed
|
||||
|
||||
1. **requirements.txt** - Added `langchain` dependency
|
||||
2. **tradingagents/agents/utils/memory.py** - Enhanced with chunking and persistence
|
||||
3. **test_memory_chunking.py** (new) - Comprehensive test suite
|
||||
|
||||
## Code Quality
|
||||
|
||||
- ✅ **Type hints preserved** where applicable
|
||||
- ✅ **Docstrings updated** with new behavior
|
||||
- ✅ **Error handling** added for robustness
|
||||
- ✅ **Comments added** for complex logic
|
||||
- ✅ **Minimal code changes** for easy review
|
||||
- ✅ **No breaking changes** to existing API
|
||||
|
||||
## Checklist
|
||||
|
||||
- [x] Code follows project style guidelines
|
||||
- [x] Self-review completed
|
||||
- [x] Comments added for complex areas
|
||||
- [x] Documentation updated (this PR description)
|
||||
- [x] Backward compatibility maintained
|
||||
- [x] Tests added/updated
|
||||
- [x] All tests pass locally
|
||||
- [x] No breaking changes
|
||||
- [x] Dependencies documented in requirements.txt
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Potential follow-ups (not in this PR):
|
||||
1. Parallel chunk embedding for faster processing
|
||||
2. Configurable chunk size and overlap
|
||||
3. Alternative chunking strategies (semantic, sentence-based)
|
||||
4. Embedding caching for repeated texts
|
||||
5. Compression for large persistent collections
|
||||
|
||||
## Questions & Answers
|
||||
|
||||
**Q: Why return a list instead of single embedding?**
|
||||
A: Consistency - both short and long texts now return lists. Makes API more predictable and easier to handle.
|
||||
|
||||
**Q: Why average embeddings for multi-chunk queries?**
|
||||
A: Common approach in semantic search - represents overall meaning while avoiding bias toward any single chunk.
|
||||
|
||||
**Q: Is persistent_dir required?**
|
||||
A: No - optional parameter. Defaults to in-memory storage for backward compatibility.
|
||||
|
||||
**Q: Can I migrate existing in-memory data to persistent storage?**
|
||||
A: Not directly - would need to re-add situations with persistent_dir specified.
|
||||
|
||||
**Q: Performance impact?**
|
||||
A: Negligible for short texts. Linear increase for long texts based on chunk count.
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
This implementation is based on patterns from:
|
||||
- BA2TradePlatform integration testing
|
||||
- LangChain best practices for text chunking
|
||||
- ChromaDB documentation for persistent storage
|
||||
|
||||
---
|
||||
|
||||
**Ready for Review** ✅
|
||||
|
||||
Please let me know if you have any questions or suggestions!
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
typing-extensions
|
||||
langchain
|
||||
langchain-openai
|
||||
langchain-experimental
|
||||
pandas
|
||||
|
|
|
|||
|
|
@ -0,0 +1,271 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify memory.py chunking and persistent storage improvements.
|
||||
|
||||
Tests:
|
||||
1. Short text handling (backward compatibility)
|
||||
2. Long text chunking with RecursiveCharacterTextSplitter
|
||||
3. Persistent storage functionality
|
||||
4. In-memory storage (backward compatibility)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
|
||||
|
||||
def test_short_text_backward_compatibility():
|
||||
"""Test that short texts work as before (single embedding)."""
|
||||
print("\n" + "="*80)
|
||||
print("TEST 1: Short Text Backward Compatibility")
|
||||
print("="*80)
|
||||
|
||||
config = {
|
||||
"backend_url": "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
# Use in-memory storage (backward compatible)
|
||||
memory = FinancialSituationMemory(name="test_short", config=config)
|
||||
|
||||
# Short situation
|
||||
short_data = [
|
||||
(
|
||||
"Tech stocks are volatile",
|
||||
"Reduce tech exposure"
|
||||
)
|
||||
]
|
||||
|
||||
memory.add_situations(short_data)
|
||||
results = memory.get_memories("Tech volatility concerns", n_matches=1)
|
||||
|
||||
if results and len(results) > 0:
|
||||
print(f"✅ Short text test passed")
|
||||
print(f" Similarity Score: {results[0]['similarity_score']:.2f}")
|
||||
print(f" Recommendation: {results[0]['recommendation']}")
|
||||
return True
|
||||
else:
|
||||
print("❌ Short text test failed")
|
||||
return False
|
||||
|
||||
|
||||
def test_long_text_chunking():
|
||||
"""Test that long texts are properly chunked."""
|
||||
print("\n" + "="*80)
|
||||
print("TEST 2: Long Text Chunking")
|
||||
print("="*80)
|
||||
|
||||
config = {
|
||||
"backend_url": "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
memory = FinancialSituationMemory(name="test_long", config=config)
|
||||
|
||||
# Create a very long situation (should trigger chunking)
|
||||
long_situation = """
|
||||
The global financial markets are experiencing unprecedented volatility across multiple asset classes.
|
||||
""" * 1000 # Repeat to make it very long
|
||||
|
||||
long_data = [
|
||||
(
|
||||
long_situation,
|
||||
"Diversify portfolio across uncorrelated assets and maintain higher cash reserves"
|
||||
)
|
||||
]
|
||||
|
||||
print(f"Long situation text length: {len(long_situation)} characters")
|
||||
|
||||
try:
|
||||
memory.add_situations(long_data)
|
||||
print("✅ Long text chunking and storage successful")
|
||||
|
||||
# Test retrieval
|
||||
results = memory.get_memories("Global market volatility", n_matches=1)
|
||||
|
||||
if results and len(results) > 0:
|
||||
print(f"✅ Long text retrieval successful")
|
||||
print(f" Similarity Score: {results[0]['similarity_score']:.2f}")
|
||||
return True
|
||||
else:
|
||||
print("❌ Long text retrieval failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Long text test failed with error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_persistent_storage():
|
||||
"""Test that persistent storage works correctly."""
|
||||
print("\n" + "="*80)
|
||||
print("TEST 3: Persistent Storage")
|
||||
print("="*80)
|
||||
|
||||
# Create temporary directory for testing
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
config = {
|
||||
"backend_url": "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
# Create memory with persistent storage
|
||||
memory1 = FinancialSituationMemory(
|
||||
name="test_persistent",
|
||||
config=config,
|
||||
symbol="AAPL",
|
||||
persistent_dir=temp_dir
|
||||
)
|
||||
|
||||
# Add data
|
||||
test_data = [
|
||||
(
|
||||
"Apple stock shows strong fundamentals with growing services revenue",
|
||||
"Maintain long position with trailing stop"
|
||||
)
|
||||
]
|
||||
|
||||
memory1.add_situations(test_data)
|
||||
print(f"✅ Data saved to persistent storage: {temp_dir}")
|
||||
|
||||
# Create new instance pointing to same directory
|
||||
memory2 = FinancialSituationMemory(
|
||||
name="test_persistent",
|
||||
config=config,
|
||||
symbol="AAPL",
|
||||
persistent_dir=temp_dir
|
||||
)
|
||||
|
||||
# Retrieve data from persistent storage
|
||||
results = memory2.get_memories("Apple fundamentals", n_matches=1)
|
||||
|
||||
if results and len(results) > 0:
|
||||
print(f"✅ Data retrieved from persistent storage")
|
||||
print(f" Similarity Score: {results[0]['similarity_score']:.2f}")
|
||||
print(f" Recommendation: {results[0]['recommendation']}")
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to retrieve data from persistent storage")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Persistent storage test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
# Cleanup
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
print(f" Cleaned up temporary directory")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def test_symbol_collection_naming():
|
||||
"""Test that symbol-based collection naming works."""
|
||||
print("\n" + "="*80)
|
||||
print("TEST 4: Symbol-Based Collection Naming")
|
||||
print("="*80)
|
||||
|
||||
config = {
|
||||
"backend_url": "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
# Create memory with symbol
|
||||
memory = FinancialSituationMemory(
|
||||
name="stock_analysis",
|
||||
config=config,
|
||||
symbol="MSFT"
|
||||
)
|
||||
|
||||
print(f"✅ Created collection with symbol: stock_analysis_MSFT")
|
||||
|
||||
# Add data
|
||||
test_data = [
|
||||
(
|
||||
"Microsoft Azure cloud revenue growing 30% YoY",
|
||||
"Increase position size due to strong cloud momentum"
|
||||
)
|
||||
]
|
||||
|
||||
memory.add_situations(test_data)
|
||||
results = memory.get_memories("Cloud revenue growth", n_matches=1)
|
||||
|
||||
if results and len(results) > 0:
|
||||
print(f"✅ Symbol-based collection test passed")
|
||||
return True
|
||||
else:
|
||||
print("❌ Symbol-based collection test failed")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
print("\n" + "="*80)
|
||||
print("MEMORY.PY CHUNKING & PERSISTENT STORAGE TEST SUITE")
|
||||
print("="*80)
|
||||
|
||||
print("\nNote: These tests require a valid OpenAI API key in your environment.")
|
||||
print("Set OPENAI_API_KEY environment variable or configure in dataflows/config.py")
|
||||
|
||||
results = []
|
||||
|
||||
# Run tests
|
||||
try:
|
||||
results.append(("Short Text Compatibility", test_short_text_backward_compatibility()))
|
||||
except Exception as e:
|
||||
print(f"❌ Short text test crashed: {e}")
|
||||
results.append(("Short Text Compatibility", False))
|
||||
|
||||
try:
|
||||
results.append(("Long Text Chunking", test_long_text_chunking()))
|
||||
except Exception as e:
|
||||
print(f"❌ Long text test crashed: {e}")
|
||||
results.append(("Long Text Chunking", False))
|
||||
|
||||
try:
|
||||
results.append(("Persistent Storage", test_persistent_storage()))
|
||||
except Exception as e:
|
||||
print(f"❌ Persistent storage test crashed: {e}")
|
||||
results.append(("Persistent Storage", False))
|
||||
|
||||
try:
|
||||
results.append(("Symbol Collection Naming", test_symbol_collection_naming()))
|
||||
except Exception as e:
|
||||
print(f"❌ Symbol naming test crashed: {e}")
|
||||
results.append(("Symbol Collection Naming", False))
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*80)
|
||||
print("TEST SUMMARY")
|
||||
print("="*80)
|
||||
|
||||
for test_name, passed in results:
|
||||
status = "✅ PASSED" if passed else "❌ FAILED"
|
||||
print(f"{status}: {test_name}")
|
||||
|
||||
all_passed = all(passed for _, passed in results)
|
||||
|
||||
if all_passed:
|
||||
print("\n✅ ALL TESTS PASSED!")
|
||||
print("\nChanges are ready for PR:")
|
||||
print("1. Added langchain to requirements.txt")
|
||||
print("2. Implemented get_embedding chunking with RecursiveCharacterTextSplitter")
|
||||
print("3. Added persistent ChromaDB storage support")
|
||||
print("4. Maintained full backward compatibility")
|
||||
else:
|
||||
print("\n❌ SOME TESTS FAILED")
|
||||
print("Please review the failures above.")
|
||||
|
||||
return all_passed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
|
|
@ -1,25 +1,112 @@
|
|||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import numpy as np
|
||||
import os
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name, config):
|
||||
def __init__(self, name, config, symbol=None, persistent_dir=None):
|
||||
if config["backend_url"] == "http://localhost:11434/v1":
|
||||
self.embedding = "nomic-embed-text"
|
||||
else:
|
||||
self.embedding = "text-embedding-3-small"
|
||||
self.client = OpenAI(base_url=config["backend_url"])
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
# Get API key from config
|
||||
try:
|
||||
from ...dataflows.config import get_openai_api_key
|
||||
api_key = get_openai_api_key()
|
||||
except ImportError:
|
||||
api_key = None
|
||||
|
||||
self.client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
||||
|
||||
# Use persistent storage if directory is provided
|
||||
if persistent_dir:
|
||||
os.makedirs(persistent_dir, exist_ok=True)
|
||||
|
||||
# Use PersistentClient for disk storage
|
||||
chroma_settings = Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=True,
|
||||
is_persistent=True
|
||||
)
|
||||
|
||||
try:
|
||||
self.chroma_client = chromadb.PersistentClient(
|
||||
path=persistent_dir,
|
||||
settings=chroma_settings
|
||||
)
|
||||
except Exception:
|
||||
# Fallback: try without settings if there are compatibility issues
|
||||
self.chroma_client = chromadb.PersistentClient(path=persistent_dir)
|
||||
else:
|
||||
# Use in-memory client for backward compatibility
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
|
||||
# Create collection name
|
||||
if symbol:
|
||||
collection_name = f"{name}_{symbol}"
|
||||
else:
|
||||
collection_name = name
|
||||
|
||||
# Sanitize collection name (ChromaDB requires alphanumeric, underscore, hyphen)
|
||||
collection_name = collection_name.replace(' ', '_').replace('.', '_')
|
||||
|
||||
# Try to get existing collection or create new one
|
||||
try:
|
||||
self.situation_collection = self.chroma_client.get_collection(name=collection_name)
|
||||
except:
|
||||
self.situation_collection = self.chroma_client.create_collection(name=collection_name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get OpenAI embedding for a text"""
|
||||
"""Get OpenAI embeddings for a text, using RecursiveCharacterTextSplitter for long texts.
|
||||
|
||||
response = self.client.embeddings.create(
|
||||
model=self.embedding, input=text
|
||||
Returns:
|
||||
list: List of embeddings (one per chunk). If text is short, returns list with single embedding.
|
||||
"""
|
||||
# text-embedding-3-small has a max context length of 8192 tokens
|
||||
# Conservative estimate: ~3 characters per token for safety margin
|
||||
max_chars = 24000 # ~8000 tokens * 3 chars/token
|
||||
|
||||
if len(text) <= max_chars:
|
||||
# Text is short enough, get embedding directly
|
||||
response = self.client.embeddings.create(
|
||||
model=self.embedding, input=text
|
||||
)
|
||||
return [response.data[0].embedding]
|
||||
|
||||
# Text is too long, use RecursiveCharacterTextSplitter
|
||||
print(f"Text length {len(text)} exceeds limit, splitting into chunks for embedding")
|
||||
|
||||
# Use RecursiveCharacterTextSplitter for intelligent chunking
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=max_chars - 1000, # Leave some buffer
|
||||
chunk_overlap=500, # Overlap to preserve context
|
||||
length_function=len,
|
||||
separators=["\n\n", "\n", ". ", " ", ""] # Try to split at natural boundaries
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
chunks = text_splitter.split_text(text)
|
||||
print(f"Split text into {len(chunks)} chunks for embedding")
|
||||
|
||||
# Get embeddings for all chunks
|
||||
chunk_embeddings = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
try:
|
||||
response = self.client.embeddings.create(
|
||||
model=self.embedding, input=chunk
|
||||
)
|
||||
chunk_embeddings.append(response.data[0].embedding)
|
||||
except Exception as e:
|
||||
print(f"Failed to get embedding for chunk {i}: {e}")
|
||||
continue
|
||||
|
||||
if not chunk_embeddings:
|
||||
raise ValueError("Failed to get embeddings for any chunks")
|
||||
|
||||
return chunk_embeddings
|
||||
|
||||
def add_situations(self, situations_and_advice):
|
||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||
|
|
@ -30,12 +117,19 @@ class FinancialSituationMemory:
|
|||
embeddings = []
|
||||
|
||||
offset = self.situation_collection.count()
|
||||
current_id = offset
|
||||
|
||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
||||
situations.append(situation)
|
||||
advice.append(recommendation)
|
||||
ids.append(str(offset + i))
|
||||
embeddings.append(self.get_embedding(situation))
|
||||
for situation, recommendation in situations_and_advice:
|
||||
# Get embeddings (returns list of embeddings for chunks)
|
||||
situation_embeddings = self.get_embedding(situation)
|
||||
|
||||
# Add each chunk as a separate document
|
||||
for chunk_idx, embedding in enumerate(situation_embeddings):
|
||||
situations.append(situation) # Store full situation for each chunk
|
||||
advice.append(recommendation)
|
||||
ids.append(str(current_id))
|
||||
embeddings.append(embedding)
|
||||
current_id += 1
|
||||
|
||||
self.situation_collection.add(
|
||||
documents=situations,
|
||||
|
|
@ -46,7 +140,14 @@ class FinancialSituationMemory:
|
|||
|
||||
def get_memories(self, current_situation, n_matches=1):
|
||||
"""Find matching recommendations using OpenAI embeddings"""
|
||||
query_embedding = self.get_embedding(current_situation)
|
||||
# Get embeddings (returns list)
|
||||
query_embeddings = self.get_embedding(current_situation)
|
||||
|
||||
# Average embeddings if multiple chunks
|
||||
if len(query_embeddings) > 1:
|
||||
query_embedding = np.mean(query_embeddings, axis=0).tolist()
|
||||
else:
|
||||
query_embedding = query_embeddings[0]
|
||||
|
||||
results = self.situation_collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
|
|
|
|||
Loading…
Reference in New Issue