599 lines
20 KiB
Python
599 lines
20 KiB
Python
"""
|
|
Portfolio state persistence for saving and loading portfolio data.
|
|
|
|
This module provides functionality to save and load portfolio state
|
|
to/from JSON files and SQLite databases, including trade history,
|
|
positions, and performance snapshots.
|
|
"""
|
|
|
|
import json
|
|
import sqlite3
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from pathlib import Path
|
|
from typing import Dict, Any, List, Optional
|
|
import logging
|
|
|
|
from tradingagents.security import sanitize_path_component
|
|
from .exceptions import PersistenceError, ValidationError
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PortfolioPersistence:
    """
    Handles persistence of portfolio state to disk.

    Supports both JSON file format for simple snapshots and SQLite
    for more complex historical data and querying. All files are
    created under ``base_dir``, and filenames are sanitized with
    ``sanitize_path_component`` before any path is built.
    """
|
|
|
|
def __init__(self, base_dir: Optional[str] = None):
|
|
"""
|
|
Initialize the persistence manager.
|
|
|
|
Args:
|
|
base_dir: Base directory for portfolio data (defaults to ./portfolio_data)
|
|
"""
|
|
self.base_dir = Path(base_dir) if base_dir else Path('./portfolio_data')
|
|
self.base_dir.mkdir(parents=True, exist_ok=True)
|
|
logger.info(f"Initialized PortfolioPersistence with base_dir={self.base_dir}")
|
|
|
|
def save_to_json(
|
|
self,
|
|
portfolio_data: Dict[str, Any],
|
|
filename: str
|
|
) -> None:
|
|
"""
|
|
Save portfolio state to a JSON file.
|
|
|
|
Args:
|
|
portfolio_data: Dictionary containing portfolio state
|
|
filename: Name of the file to save to
|
|
|
|
Raises:
|
|
PersistenceError: If save operation fails
|
|
ValidationError: If filename is invalid
|
|
"""
|
|
try:
|
|
# Sanitize filename
|
|
safe_filename = sanitize_path_component(filename)
|
|
if not safe_filename.endswith('.json'):
|
|
safe_filename += '.json'
|
|
|
|
filepath = self.base_dir / safe_filename
|
|
|
|
# Convert Decimal values to strings for JSON serialization
|
|
json_data = self._prepare_for_json(portfolio_data)
|
|
|
|
# Write to file with atomic operation
|
|
temp_filepath = filepath.with_suffix('.tmp')
|
|
with open(temp_filepath, 'w') as f:
|
|
json.dump(json_data, f, indent=2, default=str)
|
|
|
|
# Atomic rename
|
|
temp_filepath.replace(filepath)
|
|
|
|
logger.info(f"Saved portfolio state to {filepath}")
|
|
|
|
except (OSError, IOError, ValueError) as e:
|
|
raise PersistenceError(f"Failed to save portfolio to JSON: {e}")
|
|
|
|
def load_from_json(self, filename: str) -> Dict[str, Any]:
|
|
"""
|
|
Load portfolio state from a JSON file.
|
|
|
|
Args:
|
|
filename: Name of the file to load from
|
|
|
|
Returns:
|
|
Dictionary containing portfolio state
|
|
|
|
Raises:
|
|
PersistenceError: If load operation fails
|
|
ValidationError: If filename is invalid
|
|
"""
|
|
try:
|
|
# Sanitize filename
|
|
safe_filename = sanitize_path_component(filename)
|
|
if not safe_filename.endswith('.json'):
|
|
safe_filename += '.json'
|
|
|
|
filepath = self.base_dir / safe_filename
|
|
|
|
if not filepath.exists():
|
|
raise PersistenceError(f"Portfolio file not found: {filepath}")
|
|
|
|
with open(filepath, 'r') as f:
|
|
data = json.load(f)
|
|
|
|
# Convert string values back to Decimal where appropriate
|
|
data = self._restore_from_json(data)
|
|
|
|
logger.info(f"Loaded portfolio state from {filepath}")
|
|
|
|
return data
|
|
|
|
except (OSError, IOError, json.JSONDecodeError) as e:
|
|
raise PersistenceError(f"Failed to load portfolio from JSON: {e}")
|
|
|
|
def save_to_sqlite(
|
|
self,
|
|
portfolio_data: Dict[str, Any],
|
|
db_name: str = 'portfolio.db'
|
|
) -> None:
|
|
"""
|
|
Save portfolio state to SQLite database.
|
|
|
|
Creates tables if they don't exist and inserts/updates data.
|
|
|
|
Args:
|
|
portfolio_data: Dictionary containing portfolio state
|
|
db_name: Name of the SQLite database file
|
|
|
|
Raises:
|
|
PersistenceError: If save operation fails
|
|
"""
|
|
try:
|
|
# Sanitize database name
|
|
safe_db_name = sanitize_path_component(db_name)
|
|
if not safe_db_name.endswith('.db'):
|
|
safe_db_name += '.db'
|
|
|
|
db_path = self.base_dir / safe_db_name
|
|
|
|
with sqlite3.connect(db_path) as conn:
|
|
self._create_tables(conn)
|
|
self._insert_portfolio_snapshot(conn, portfolio_data)
|
|
self._insert_positions(conn, portfolio_data.get('positions', {}))
|
|
self._insert_trades(conn, portfolio_data.get('trade_history', []))
|
|
|
|
logger.info(f"Saved portfolio state to SQLite: {db_path}")
|
|
|
|
except (sqlite3.Error, OSError) as e:
|
|
raise PersistenceError(f"Failed to save portfolio to SQLite: {e}")
|
|
|
|
def load_from_sqlite(
|
|
self,
|
|
db_name: str = 'portfolio.db',
|
|
snapshot_id: Optional[int] = None
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Load portfolio state from SQLite database.
|
|
|
|
Args:
|
|
db_name: Name of the SQLite database file
|
|
snapshot_id: Specific snapshot ID to load (loads latest if None)
|
|
|
|
Returns:
|
|
Dictionary containing portfolio state
|
|
|
|
Raises:
|
|
PersistenceError: If load operation fails
|
|
"""
|
|
try:
|
|
# Sanitize database name
|
|
safe_db_name = sanitize_path_component(db_name)
|
|
if not safe_db_name.endswith('.db'):
|
|
safe_db_name += '.db'
|
|
|
|
db_path = self.base_dir / safe_db_name
|
|
|
|
if not db_path.exists():
|
|
raise PersistenceError(f"Database not found: {db_path}")
|
|
|
|
with sqlite3.connect(db_path) as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
|
|
# Get snapshot
|
|
if snapshot_id is None:
|
|
# Get latest snapshot
|
|
cursor = conn.execute(
|
|
'SELECT * FROM portfolio_snapshots ORDER BY timestamp DESC LIMIT 1'
|
|
)
|
|
else:
|
|
cursor = conn.execute(
|
|
'SELECT * FROM portfolio_snapshots WHERE id = ?',
|
|
(snapshot_id,)
|
|
)
|
|
|
|
snapshot = cursor.fetchone()
|
|
if not snapshot:
|
|
raise PersistenceError("No portfolio snapshot found")
|
|
|
|
# Build portfolio data
|
|
portfolio_data = {
|
|
'cash': Decimal(snapshot['cash']),
|
|
'initial_capital': Decimal(snapshot['initial_capital']),
|
|
'commission_rate': Decimal(snapshot['commission_rate']),
|
|
'timestamp': snapshot['timestamp'],
|
|
}
|
|
|
|
# Load positions
|
|
portfolio_data['positions'] = self._load_positions(
|
|
conn, snapshot['id']
|
|
)
|
|
|
|
# Load trade history
|
|
portfolio_data['trade_history'] = self._load_trades(
|
|
conn, snapshot['id']
|
|
)
|
|
|
|
logger.info(f"Loaded portfolio state from SQLite: {db_path}")
|
|
|
|
return portfolio_data
|
|
|
|
except (sqlite3.Error, OSError) as e:
|
|
raise PersistenceError(f"Failed to load portfolio from SQLite: {e}")
|
|
|
|
def _create_tables(self, conn: sqlite3.Connection) -> None:
|
|
"""Create database tables if they don't exist."""
|
|
cursor = conn.cursor()
|
|
|
|
# Portfolio snapshots table
|
|
cursor.execute('''
|
|
CREATE TABLE IF NOT EXISTS portfolio_snapshots (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
timestamp TEXT NOT NULL,
|
|
cash TEXT NOT NULL,
|
|
initial_capital TEXT NOT NULL,
|
|
commission_rate TEXT NOT NULL,
|
|
total_value TEXT,
|
|
metadata TEXT
|
|
)
|
|
''')
|
|
|
|
# Positions table
|
|
cursor.execute('''
|
|
CREATE TABLE IF NOT EXISTS positions (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
snapshot_id INTEGER NOT NULL,
|
|
ticker TEXT NOT NULL,
|
|
quantity TEXT NOT NULL,
|
|
cost_basis TEXT NOT NULL,
|
|
sector TEXT,
|
|
opened_at TEXT NOT NULL,
|
|
last_updated TEXT NOT NULL,
|
|
stop_loss TEXT,
|
|
take_profit TEXT,
|
|
metadata TEXT,
|
|
FOREIGN KEY (snapshot_id) REFERENCES portfolio_snapshots (id)
|
|
)
|
|
''')
|
|
|
|
# Trade history table
|
|
cursor.execute('''
|
|
CREATE TABLE IF NOT EXISTS trades (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
snapshot_id INTEGER NOT NULL,
|
|
ticker TEXT NOT NULL,
|
|
entry_date TEXT NOT NULL,
|
|
exit_date TEXT,
|
|
entry_price TEXT NOT NULL,
|
|
exit_price TEXT,
|
|
quantity TEXT NOT NULL,
|
|
pnl TEXT,
|
|
pnl_percent TEXT,
|
|
commission TEXT NOT NULL,
|
|
holding_period INTEGER,
|
|
is_win INTEGER,
|
|
FOREIGN KEY (snapshot_id) REFERENCES portfolio_snapshots (id)
|
|
)
|
|
''')
|
|
|
|
# Create indices for better query performance
|
|
cursor.execute(
|
|
'CREATE INDEX IF NOT EXISTS idx_positions_snapshot ON positions(snapshot_id)'
|
|
)
|
|
cursor.execute(
|
|
'CREATE INDEX IF NOT EXISTS idx_trades_snapshot ON trades(snapshot_id)'
|
|
)
|
|
cursor.execute(
|
|
'CREATE INDEX IF NOT EXISTS idx_trades_ticker ON trades(ticker)'
|
|
)
|
|
|
|
conn.commit()
|
|
|
|
def _insert_portfolio_snapshot(
|
|
self,
|
|
conn: sqlite3.Connection,
|
|
portfolio_data: Dict[str, Any]
|
|
) -> int:
|
|
"""Insert a portfolio snapshot and return its ID."""
|
|
cursor = conn.cursor()
|
|
|
|
cursor.execute('''
|
|
INSERT INTO portfolio_snapshots
|
|
(timestamp, cash, initial_capital, commission_rate, total_value, metadata)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
''', (
|
|
portfolio_data.get('timestamp', datetime.now().isoformat()),
|
|
str(portfolio_data.get('cash', '0')),
|
|
str(portfolio_data.get('initial_capital', '0')),
|
|
str(portfolio_data.get('commission_rate', '0')),
|
|
str(portfolio_data.get('total_value', '0')),
|
|
json.dumps(portfolio_data.get('metadata', {}))
|
|
))
|
|
|
|
conn.commit()
|
|
return cursor.lastrowid
|
|
|
|
def _insert_positions(
|
|
self,
|
|
conn: sqlite3.Connection,
|
|
positions: Dict[str, Dict[str, Any]]
|
|
) -> None:
|
|
"""Insert positions into the database."""
|
|
cursor = conn.cursor()
|
|
|
|
# Get the latest snapshot ID
|
|
snapshot_id = cursor.execute(
|
|
'SELECT MAX(id) FROM portfolio_snapshots'
|
|
).fetchone()[0]
|
|
|
|
for ticker, position_data in positions.items():
|
|
cursor.execute('''
|
|
INSERT INTO positions
|
|
(snapshot_id, ticker, quantity, cost_basis, sector, opened_at,
|
|
last_updated, stop_loss, take_profit, metadata)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
''', (
|
|
snapshot_id,
|
|
ticker,
|
|
str(position_data.get('quantity', '0')),
|
|
str(position_data.get('cost_basis', '0')),
|
|
position_data.get('sector'),
|
|
position_data.get('opened_at', datetime.now().isoformat()),
|
|
position_data.get('last_updated', datetime.now().isoformat()),
|
|
str(position_data.get('stop_loss')) if position_data.get('stop_loss') else None,
|
|
str(position_data.get('take_profit')) if position_data.get('take_profit') else None,
|
|
json.dumps(position_data.get('metadata', {}))
|
|
))
|
|
|
|
conn.commit()
|
|
|
|
def _insert_trades(
|
|
self,
|
|
conn: sqlite3.Connection,
|
|
trades: List[Dict[str, Any]]
|
|
) -> None:
|
|
"""Insert trades into the database."""
|
|
cursor = conn.cursor()
|
|
|
|
# Get the latest snapshot ID
|
|
snapshot_id = cursor.execute(
|
|
'SELECT MAX(id) FROM portfolio_snapshots'
|
|
).fetchone()[0]
|
|
|
|
for trade_data in trades:
|
|
cursor.execute('''
|
|
INSERT INTO trades
|
|
(snapshot_id, ticker, entry_date, exit_date, entry_price, exit_price,
|
|
quantity, pnl, pnl_percent, commission, holding_period, is_win)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
''', (
|
|
snapshot_id,
|
|
trade_data.get('ticker', ''),
|
|
trade_data.get('entry_date', ''),
|
|
trade_data.get('exit_date'),
|
|
str(trade_data.get('entry_price', '0')),
|
|
str(trade_data.get('exit_price')) if trade_data.get('exit_price') else None,
|
|
str(trade_data.get('quantity', '0')),
|
|
str(trade_data.get('pnl')) if trade_data.get('pnl') else None,
|
|
str(trade_data.get('pnl_percent')) if trade_data.get('pnl_percent') else None,
|
|
str(trade_data.get('commission', '0')),
|
|
trade_data.get('holding_period'),
|
|
1 if trade_data.get('is_win') else 0
|
|
))
|
|
|
|
conn.commit()
|
|
|
|
def _load_positions(
|
|
self,
|
|
conn: sqlite3.Connection,
|
|
snapshot_id: int
|
|
) -> Dict[str, Dict[str, Any]]:
|
|
"""Load positions from the database."""
|
|
cursor = conn.execute(
|
|
'SELECT * FROM positions WHERE snapshot_id = ?',
|
|
(snapshot_id,)
|
|
)
|
|
|
|
positions = {}
|
|
for row in cursor:
|
|
ticker = row['ticker']
|
|
positions[ticker] = {
|
|
'quantity': row['quantity'],
|
|
'cost_basis': row['cost_basis'],
|
|
'sector': row['sector'],
|
|
'opened_at': row['opened_at'],
|
|
'last_updated': row['last_updated'],
|
|
'stop_loss': row['stop_loss'],
|
|
'take_profit': row['take_profit'],
|
|
'metadata': json.loads(row['metadata']) if row['metadata'] else {}
|
|
}
|
|
|
|
return positions
|
|
|
|
def _load_trades(
|
|
self,
|
|
conn: sqlite3.Connection,
|
|
snapshot_id: int
|
|
) -> List[Dict[str, Any]]:
|
|
"""Load trades from the database."""
|
|
cursor = conn.execute(
|
|
'SELECT * FROM trades WHERE snapshot_id = ?',
|
|
(snapshot_id,)
|
|
)
|
|
|
|
trades = []
|
|
for row in cursor:
|
|
trades.append({
|
|
'ticker': row['ticker'],
|
|
'entry_date': row['entry_date'],
|
|
'exit_date': row['exit_date'],
|
|
'entry_price': row['entry_price'],
|
|
'exit_price': row['exit_price'],
|
|
'quantity': row['quantity'],
|
|
'pnl': row['pnl'],
|
|
'pnl_percent': row['pnl_percent'],
|
|
'commission': row['commission'],
|
|
'holding_period': row['holding_period'],
|
|
'is_win': bool(row['is_win'])
|
|
})
|
|
|
|
return trades
|
|
|
|
def _prepare_for_json(self, data: Any) -> Any:
|
|
"""Recursively prepare data for JSON serialization."""
|
|
if isinstance(data, Decimal):
|
|
return str(data)
|
|
elif isinstance(data, datetime):
|
|
return data.isoformat()
|
|
elif isinstance(data, dict):
|
|
return {k: self._prepare_for_json(v) for k, v in data.items()}
|
|
elif isinstance(data, list):
|
|
return [self._prepare_for_json(item) for item in data]
|
|
else:
|
|
return data
|
|
|
|
def _restore_from_json(self, data: Any) -> Any:
|
|
"""Recursively restore data types from JSON."""
|
|
if isinstance(data, dict):
|
|
# Check for known keys that should be Decimal
|
|
decimal_keys = {
|
|
'cash', 'initial_capital', 'commission_rate', 'quantity',
|
|
'cost_basis', 'stop_loss', 'take_profit', 'entry_price',
|
|
'exit_price', 'pnl', 'pnl_percent', 'commission', 'limit_price',
|
|
'stop_price', 'target_price', 'filled_price'
|
|
}
|
|
|
|
result = {}
|
|
for k, v in data.items():
|
|
if k in decimal_keys and v is not None:
|
|
try:
|
|
result[k] = Decimal(str(v))
|
|
except:
|
|
result[k] = v
|
|
else:
|
|
result[k] = self._restore_from_json(v)
|
|
|
|
return result
|
|
elif isinstance(data, list):
|
|
return [self._restore_from_json(item) for item in data]
|
|
else:
|
|
return data
|
|
|
|
def export_to_csv(
|
|
self,
|
|
trades: List[Dict[str, Any]],
|
|
filename: str
|
|
) -> None:
|
|
"""
|
|
Export trade history to CSV file.
|
|
|
|
Args:
|
|
trades: List of trade records
|
|
filename: Name of the CSV file
|
|
|
|
Raises:
|
|
PersistenceError: If export fails
|
|
"""
|
|
try:
|
|
import csv
|
|
|
|
safe_filename = sanitize_path_component(filename)
|
|
if not safe_filename.endswith('.csv'):
|
|
safe_filename += '.csv'
|
|
|
|
filepath = self.base_dir / safe_filename
|
|
|
|
if not trades:
|
|
logger.warning("No trades to export")
|
|
return
|
|
|
|
# Get all unique keys from trades
|
|
fieldnames = set()
|
|
for trade in trades:
|
|
fieldnames.update(trade.keys())
|
|
|
|
fieldnames = sorted(fieldnames)
|
|
|
|
with open(filepath, 'w', newline='') as f:
|
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
|
writer.writeheader()
|
|
writer.writerows(trades)
|
|
|
|
logger.info(f"Exported {len(trades)} trades to {filepath}")
|
|
|
|
except (OSError, IOError) as e:
|
|
raise PersistenceError(f"Failed to export to CSV: {e}")
|
|
|
|
def cleanup_old_snapshots(
|
|
self,
|
|
db_name: str = 'portfolio.db',
|
|
keep_last_n: int = 100
|
|
) -> int:
|
|
"""
|
|
Clean up old snapshots from the database.
|
|
|
|
Args:
|
|
db_name: Name of the SQLite database file
|
|
keep_last_n: Number of latest snapshots to keep
|
|
|
|
Returns:
|
|
Number of snapshots deleted
|
|
|
|
Raises:
|
|
PersistenceError: If cleanup fails
|
|
"""
|
|
try:
|
|
safe_db_name = sanitize_path_component(db_name)
|
|
if not safe_db_name.endswith('.db'):
|
|
safe_db_name += '.db'
|
|
|
|
db_path = self.base_dir / safe_db_name
|
|
|
|
if not db_path.exists():
|
|
return 0
|
|
|
|
with sqlite3.connect(db_path) as conn:
|
|
cursor = conn.cursor()
|
|
|
|
# Get IDs of snapshots to delete
|
|
cursor.execute('''
|
|
SELECT id FROM portfolio_snapshots
|
|
ORDER BY timestamp DESC
|
|
LIMIT -1 OFFSET ?
|
|
''', (keep_last_n,))
|
|
|
|
ids_to_delete = [row[0] for row in cursor.fetchall()]
|
|
|
|
if not ids_to_delete:
|
|
return 0
|
|
|
|
# Delete related positions and trades
|
|
cursor.execute(
|
|
f'DELETE FROM positions WHERE snapshot_id IN ({",".join("?" * len(ids_to_delete))})',
|
|
ids_to_delete
|
|
)
|
|
cursor.execute(
|
|
f'DELETE FROM trades WHERE snapshot_id IN ({",".join("?" * len(ids_to_delete))})',
|
|
ids_to_delete
|
|
)
|
|
|
|
# Delete snapshots
|
|
cursor.execute(
|
|
f'DELETE FROM portfolio_snapshots WHERE id IN ({",".join("?" * len(ids_to_delete))})',
|
|
ids_to_delete
|
|
)
|
|
|
|
conn.commit()
|
|
|
|
logger.info(f"Deleted {len(ids_to_delete)} old snapshots")
|
|
|
|
return len(ids_to_delete)
|
|
|
|
except (sqlite3.Error, OSError) as e:
|
|
raise PersistenceError(f"Failed to cleanup old snapshots: {e}")
|