# TradingAgents/tradingagents/portfolio/persistence.py
# (599 lines, 20 KiB, Python)
"""
Portfolio state persistence for saving and loading portfolio data.
This module provides functionality to save and load portfolio state
to/from JSON files and SQLite databases, including trade history,
positions, and performance snapshots.
"""
import json
import sqlite3
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Dict, Any, List, Optional
import logging
from tradingagents.security import sanitize_path_component
from .exceptions import PersistenceError, ValidationError
logger = logging.getLogger(__name__)
class PortfolioPersistence:
    """
    Persist portfolio state to disk.

    Two backends are supported: JSON files for simple snapshots, and
    SQLite databases for richer historical data and querying.
    """

    def __init__(self, base_dir: Optional[str] = None):
        """
        Set up the persistence manager and ensure its data directory exists.

        Args:
            base_dir: Base directory for portfolio data
                (defaults to ./portfolio_data)
        """
        # Fall back to a local ./portfolio_data directory when no base is given.
        chosen_dir = Path(base_dir) if base_dir else Path('./portfolio_data')
        self.base_dir = chosen_dir
        self.base_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Initialized PortfolioPersistence with base_dir={self.base_dir}")
def save_to_json(
    self,
    portfolio_data: Dict[str, Any],
    filename: str
) -> None:
    """
    Save portfolio state to a JSON file.

    The write is atomic: data is written to a sibling temporary file
    which is then renamed over the target, so a crash mid-write never
    leaves a truncated portfolio file behind.

    Args:
        portfolio_data: Dictionary containing portfolio state
        filename: Name of the file to save to

    Raises:
        PersistenceError: If save operation fails
        ValidationError: If filename is invalid
    """
    temp_filepath = None
    try:
        # Sanitize filename to prevent path traversal
        safe_filename = sanitize_path_component(filename)
        if not safe_filename.endswith('.json'):
            safe_filename += '.json'
        filepath = self.base_dir / safe_filename
        # Convert Decimal values to strings for JSON serialization
        json_data = self._prepare_for_json(portfolio_data)
        # Write to a temp file, then atomically rename over the target
        temp_filepath = filepath.with_suffix('.tmp')
        with open(temp_filepath, 'w') as f:
            json.dump(json_data, f, indent=2, default=str)
        temp_filepath.replace(filepath)
        logger.info(f"Saved portfolio state to {filepath}")
    except (OSError, IOError, ValueError, TypeError) as e:
        # TypeError covers json serialization failures; clean up any
        # partially-written temp file so it is not left behind.
        if temp_filepath is not None:
            try:
                temp_filepath.unlink()
            except OSError:
                pass
        raise PersistenceError(f"Failed to save portfolio to JSON: {e}") from e
def load_from_json(self, filename: str) -> Dict[str, Any]:
    """
    Load portfolio state from a JSON file.

    Args:
        filename: Name of the file to load from

    Returns:
        Dictionary containing portfolio state, with known numeric
        fields restored to Decimal.

    Raises:
        PersistenceError: If load operation fails or file is missing
        ValidationError: If filename is invalid
    """
    try:
        # Sanitize filename to prevent path traversal
        safe_filename = sanitize_path_component(filename)
        if not safe_filename.endswith('.json'):
            safe_filename += '.json'
        filepath = self.base_dir / safe_filename
        if not filepath.exists():
            # PersistenceError is not in the except tuple below, so this
            # propagates to the caller unchanged.
            raise PersistenceError(f"Portfolio file not found: {filepath}")
        with open(filepath, 'r') as f:
            data = json.load(f)
        # Convert string values back to Decimal where appropriate
        data = self._restore_from_json(data)
        logger.info(f"Loaded portfolio state from {filepath}")
        return data
    except (OSError, IOError, json.JSONDecodeError) as e:
        # Chain the cause so the original traceback is preserved.
        raise PersistenceError(f"Failed to load portfolio from JSON: {e}") from e
def save_to_sqlite(
    self,
    portfolio_data: Dict[str, Any],
    db_name: str = 'portfolio.db'
) -> None:
    """
    Save portfolio state to SQLite database.

    Creates tables if they don't exist and inserts the snapshot plus
    its positions and trade history.

    Args:
        portfolio_data: Dictionary containing portfolio state
        db_name: Name of the SQLite database file

    Raises:
        PersistenceError: If save operation fails
    """
    try:
        # Sanitize database name to prevent path traversal
        safe_db_name = sanitize_path_component(db_name)
        if not safe_db_name.endswith('.db'):
            safe_db_name += '.db'
        db_path = self.base_dir / safe_db_name
        # NOTE: sqlite3's connection context manager only scopes the
        # transaction (commit/rollback) — it does NOT close the
        # connection, so close it explicitly to avoid a handle leak.
        conn = sqlite3.connect(db_path)
        try:
            with conn:
                self._create_tables(conn)
                self._insert_portfolio_snapshot(conn, portfolio_data)
                self._insert_positions(conn, portfolio_data.get('positions', {}))
                self._insert_trades(conn, portfolio_data.get('trade_history', []))
        finally:
            conn.close()
        logger.info(f"Saved portfolio state to SQLite: {db_path}")
    except (sqlite3.Error, OSError) as e:
        raise PersistenceError(f"Failed to save portfolio to SQLite: {e}") from e
def load_from_sqlite(
    self,
    db_name: str = 'portfolio.db',
    snapshot_id: Optional[int] = None
) -> Dict[str, Any]:
    """
    Load portfolio state from SQLite database.

    Args:
        db_name: Name of the SQLite database file
        snapshot_id: Specific snapshot ID to load (loads latest if None)

    Returns:
        Dictionary containing portfolio state (cash/initial_capital/
        commission_rate as Decimal, plus positions and trade_history)

    Raises:
        PersistenceError: If load operation fails or no snapshot exists
    """
    try:
        # Sanitize database name to prevent path traversal
        safe_db_name = sanitize_path_component(db_name)
        if not safe_db_name.endswith('.db'):
            safe_db_name += '.db'
        db_path = self.base_dir / safe_db_name
        if not db_path.exists():
            raise PersistenceError(f"Database not found: {db_path}")
        # sqlite3's context manager does not close the connection, so
        # manage its lifetime explicitly with try/finally.
        conn = sqlite3.connect(db_path)
        try:
            conn.row_factory = sqlite3.Row
            if snapshot_id is None:
                # Get latest snapshot
                cursor = conn.execute(
                    'SELECT * FROM portfolio_snapshots ORDER BY timestamp DESC LIMIT 1'
                )
            else:
                cursor = conn.execute(
                    'SELECT * FROM portfolio_snapshots WHERE id = ?',
                    (snapshot_id,)
                )
            snapshot = cursor.fetchone()
            if not snapshot:
                raise PersistenceError("No portfolio snapshot found")
            # Monetary fields are stored as TEXT; restore them to Decimal.
            portfolio_data = {
                'cash': Decimal(snapshot['cash']),
                'initial_capital': Decimal(snapshot['initial_capital']),
                'commission_rate': Decimal(snapshot['commission_rate']),
                'timestamp': snapshot['timestamp'],
            }
            portfolio_data['positions'] = self._load_positions(
                conn, snapshot['id']
            )
            portfolio_data['trade_history'] = self._load_trades(
                conn, snapshot['id']
            )
        finally:
            conn.close()
        logger.info(f"Loaded portfolio state from SQLite: {db_path}")
        return portfolio_data
    except (sqlite3.Error, OSError) as e:
        raise PersistenceError(f"Failed to load portfolio from SQLite: {e}") from e
def _create_tables(self, conn: sqlite3.Connection) -> None:
"""Create database tables if they don't exist."""
cursor = conn.cursor()
# Portfolio snapshots table
cursor.execute('''
CREATE TABLE IF NOT EXISTS portfolio_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
cash TEXT NOT NULL,
initial_capital TEXT NOT NULL,
commission_rate TEXT NOT NULL,
total_value TEXT,
metadata TEXT
)
''')
# Positions table
cursor.execute('''
CREATE TABLE IF NOT EXISTS positions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
snapshot_id INTEGER NOT NULL,
ticker TEXT NOT NULL,
quantity TEXT NOT NULL,
cost_basis TEXT NOT NULL,
sector TEXT,
opened_at TEXT NOT NULL,
last_updated TEXT NOT NULL,
stop_loss TEXT,
take_profit TEXT,
metadata TEXT,
FOREIGN KEY (snapshot_id) REFERENCES portfolio_snapshots (id)
)
''')
# Trade history table
cursor.execute('''
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
snapshot_id INTEGER NOT NULL,
ticker TEXT NOT NULL,
entry_date TEXT NOT NULL,
exit_date TEXT,
entry_price TEXT NOT NULL,
exit_price TEXT,
quantity TEXT NOT NULL,
pnl TEXT,
pnl_percent TEXT,
commission TEXT NOT NULL,
holding_period INTEGER,
is_win INTEGER,
FOREIGN KEY (snapshot_id) REFERENCES portfolio_snapshots (id)
)
''')
# Create indices for better query performance
cursor.execute(
'CREATE INDEX IF NOT EXISTS idx_positions_snapshot ON positions(snapshot_id)'
)
cursor.execute(
'CREATE INDEX IF NOT EXISTS idx_trades_snapshot ON trades(snapshot_id)'
)
cursor.execute(
'CREATE INDEX IF NOT EXISTS idx_trades_ticker ON trades(ticker)'
)
conn.commit()
def _insert_portfolio_snapshot(
self,
conn: sqlite3.Connection,
portfolio_data: Dict[str, Any]
) -> int:
"""Insert a portfolio snapshot and return its ID."""
cursor = conn.cursor()
cursor.execute('''
INSERT INTO portfolio_snapshots
(timestamp, cash, initial_capital, commission_rate, total_value, metadata)
VALUES (?, ?, ?, ?, ?, ?)
''', (
portfolio_data.get('timestamp', datetime.now().isoformat()),
str(portfolio_data.get('cash', '0')),
str(portfolio_data.get('initial_capital', '0')),
str(portfolio_data.get('commission_rate', '0')),
str(portfolio_data.get('total_value', '0')),
json.dumps(portfolio_data.get('metadata', {}))
))
conn.commit()
return cursor.lastrowid
def _insert_positions(
self,
conn: sqlite3.Connection,
positions: Dict[str, Dict[str, Any]]
) -> None:
"""Insert positions into the database."""
cursor = conn.cursor()
# Get the latest snapshot ID
snapshot_id = cursor.execute(
'SELECT MAX(id) FROM portfolio_snapshots'
).fetchone()[0]
for ticker, position_data in positions.items():
cursor.execute('''
INSERT INTO positions
(snapshot_id, ticker, quantity, cost_basis, sector, opened_at,
last_updated, stop_loss, take_profit, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
snapshot_id,
ticker,
str(position_data.get('quantity', '0')),
str(position_data.get('cost_basis', '0')),
position_data.get('sector'),
position_data.get('opened_at', datetime.now().isoformat()),
position_data.get('last_updated', datetime.now().isoformat()),
str(position_data.get('stop_loss')) if position_data.get('stop_loss') else None,
str(position_data.get('take_profit')) if position_data.get('take_profit') else None,
json.dumps(position_data.get('metadata', {}))
))
conn.commit()
def _insert_trades(
self,
conn: sqlite3.Connection,
trades: List[Dict[str, Any]]
) -> None:
"""Insert trades into the database."""
cursor = conn.cursor()
# Get the latest snapshot ID
snapshot_id = cursor.execute(
'SELECT MAX(id) FROM portfolio_snapshots'
).fetchone()[0]
for trade_data in trades:
cursor.execute('''
INSERT INTO trades
(snapshot_id, ticker, entry_date, exit_date, entry_price, exit_price,
quantity, pnl, pnl_percent, commission, holding_period, is_win)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
snapshot_id,
trade_data.get('ticker', ''),
trade_data.get('entry_date', ''),
trade_data.get('exit_date'),
str(trade_data.get('entry_price', '0')),
str(trade_data.get('exit_price')) if trade_data.get('exit_price') else None,
str(trade_data.get('quantity', '0')),
str(trade_data.get('pnl')) if trade_data.get('pnl') else None,
str(trade_data.get('pnl_percent')) if trade_data.get('pnl_percent') else None,
str(trade_data.get('commission', '0')),
trade_data.get('holding_period'),
1 if trade_data.get('is_win') else 0
))
conn.commit()
def _load_positions(
self,
conn: sqlite3.Connection,
snapshot_id: int
) -> Dict[str, Dict[str, Any]]:
"""Load positions from the database."""
cursor = conn.execute(
'SELECT * FROM positions WHERE snapshot_id = ?',
(snapshot_id,)
)
positions = {}
for row in cursor:
ticker = row['ticker']
positions[ticker] = {
'quantity': row['quantity'],
'cost_basis': row['cost_basis'],
'sector': row['sector'],
'opened_at': row['opened_at'],
'last_updated': row['last_updated'],
'stop_loss': row['stop_loss'],
'take_profit': row['take_profit'],
'metadata': json.loads(row['metadata']) if row['metadata'] else {}
}
return positions
def _load_trades(
self,
conn: sqlite3.Connection,
snapshot_id: int
) -> List[Dict[str, Any]]:
"""Load trades from the database."""
cursor = conn.execute(
'SELECT * FROM trades WHERE snapshot_id = ?',
(snapshot_id,)
)
trades = []
for row in cursor:
trades.append({
'ticker': row['ticker'],
'entry_date': row['entry_date'],
'exit_date': row['exit_date'],
'entry_price': row['entry_price'],
'exit_price': row['exit_price'],
'quantity': row['quantity'],
'pnl': row['pnl'],
'pnl_percent': row['pnl_percent'],
'commission': row['commission'],
'holding_period': row['holding_period'],
'is_win': bool(row['is_win'])
})
return trades
def _prepare_for_json(self, data: Any) -> Any:
"""Recursively prepare data for JSON serialization."""
if isinstance(data, Decimal):
return str(data)
elif isinstance(data, datetime):
return data.isoformat()
elif isinstance(data, dict):
return {k: self._prepare_for_json(v) for k, v in data.items()}
elif isinstance(data, list):
return [self._prepare_for_json(item) for item in data]
else:
return data
def _restore_from_json(self, data: Any) -> Any:
"""Recursively restore data types from JSON."""
if isinstance(data, dict):
# Check for known keys that should be Decimal
decimal_keys = {
'cash', 'initial_capital', 'commission_rate', 'quantity',
'cost_basis', 'stop_loss', 'take_profit', 'entry_price',
'exit_price', 'pnl', 'pnl_percent', 'commission', 'limit_price',
'stop_price', 'target_price', 'filled_price'
}
result = {}
for k, v in data.items():
if k in decimal_keys and v is not None:
try:
result[k] = Decimal(str(v))
except:
result[k] = v
else:
result[k] = self._restore_from_json(v)
return result
elif isinstance(data, list):
return [self._restore_from_json(item) for item in data]
else:
return data
def export_to_csv(
    self,
    trades: List[Dict[str, Any]],
    filename: str
) -> None:
    """
    Export trade history to CSV file.

    The header is the sorted union of keys across all trades, so
    heterogeneous records are supported; missing fields are written
    as empty cells.

    Args:
        trades: List of trade records
        filename: Name of the CSV file

    Raises:
        PersistenceError: If export fails
        ValidationError: If filename is invalid
    """
    try:
        import csv
        safe_filename = sanitize_path_component(filename)
        if not safe_filename.endswith('.csv'):
            safe_filename += '.csv'
        filepath = self.base_dir / safe_filename
        if not trades:
            # Nothing to export; no file is created or truncated.
            logger.warning("No trades to export")
            return
        # Union of keys across all trades (records may differ in shape);
        # build the set first, then a distinct sorted list for the header.
        field_set = set()
        for trade in trades:
            field_set.update(trade.keys())
        fieldnames = sorted(field_set)
        # newline='' is required by the csv module to avoid blank rows
        # on Windows.
        with open(filepath, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(trades)
        logger.info(f"Exported {len(trades)} trades to {filepath}")
    except (OSError, IOError) as e:
        raise PersistenceError(f"Failed to export to CSV: {e}") from e
def cleanup_old_snapshots(
    self,
    db_name: str = 'portfolio.db',
    keep_last_n: int = 100
) -> int:
    """
    Delete all but the newest ``keep_last_n`` snapshots, together with
    their associated positions and trades.

    Args:
        db_name: Name of the SQLite database file
        keep_last_n: Number of latest snapshots to keep

    Returns:
        Number of snapshots deleted (0 if the database does not exist)

    Raises:
        PersistenceError: If cleanup fails
    """
    try:
        safe_db_name = sanitize_path_component(db_name)
        if not safe_db_name.endswith('.db'):
            safe_db_name += '.db'
        db_path = self.base_dir / safe_db_name
        if not db_path.exists():
            return 0
        # sqlite3's context manager does not close the connection;
        # manage its lifetime explicitly to avoid a handle leak.
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            # LIMIT -1 means "no limit" in SQLite, so this selects every
            # snapshot beyond the newest keep_last_n.
            cursor.execute('''
                SELECT id FROM portfolio_snapshots
                ORDER BY timestamp DESC
                LIMIT -1 OFFSET ?
            ''', (keep_last_n,))
            ids_to_delete = [row[0] for row in cursor.fetchall()]
            if not ids_to_delete:
                return 0
            # One bound parameter per id. NOTE(review): very large delete
            # batches could exceed SQLite's host-variable limit — confirm
            # expected snapshot volumes if keep_last_n is ever lowered
            # drastically.
            placeholders = ",".join("?" * len(ids_to_delete))
            # Delete child rows first to respect the FK relationship.
            cursor.execute(
                f'DELETE FROM positions WHERE snapshot_id IN ({placeholders})',
                ids_to_delete
            )
            cursor.execute(
                f'DELETE FROM trades WHERE snapshot_id IN ({placeholders})',
                ids_to_delete
            )
            cursor.execute(
                f'DELETE FROM portfolio_snapshots WHERE id IN ({placeholders})',
                ids_to_delete
            )
            conn.commit()
        finally:
            conn.close()
        logger.info(f"Deleted {len(ids_to_delete)} old snapshots")
        return len(ids_to_delete)
    except (sqlite3.Error, OSError) as e:
        raise PersistenceError(f"Failed to cleanup old snapshots: {e}") from e