vibe-coding-cn/libs/external/chat-vault/src/storage.py

247 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""SQLite 存储模块 - 完整版"""
import sqlite3
import json
import os
import datetime
import tiktoken
SCHEMA_VERSION = 5
CLIS = ('codex', 'kiro', 'gemini', 'claude')
_encoder = tiktoken.get_encoding("cl100k_base")
def count_tokens(text: str) -> int:
return len(_encoder.encode(text)) if text else 0
class ChatStorage:
def __init__(self, db_path: str):
self.db_path = db_path
os.makedirs(os.path.dirname(db_path) or '.', exist_ok=True)
self._init_db()
def _conn(self):
return sqlite3.connect(self.db_path)
def _init_db(self):
with self._conn() as conn:
conn.execute('''CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value TEXT)''')
for cli in CLIS:
conn.execute(f'''CREATE TABLE IF NOT EXISTS meta_{cli} (key TEXT PRIMARY KEY, value TEXT)''')
conn.execute('''
CREATE TABLE IF NOT EXISTS sessions (
file_path TEXT PRIMARY KEY,
session_id TEXT,
source TEXT NOT NULL,
cwd TEXT,
messages TEXT,
file_mtime INTEGER,
start_time TEXT,
token_count INTEGER DEFAULT 0
)
''')
conn.execute('CREATE INDEX IF NOT EXISTS idx_source ON sessions(source)')
conn.execute('CREATE INDEX IF NOT EXISTS idx_session_id ON sessions(session_id)')
conn.execute('CREATE INDEX IF NOT EXISTS idx_start_time ON sessions(start_time)')
self._set_meta('meta', 'schema_version', str(SCHEMA_VERSION))
def _set_meta(self, table: str, key: str, value: str):
with self._conn() as conn:
conn.execute(f'INSERT OR REPLACE INTO {table} (key, value) VALUES (?, ?)', (key, value))
def _get_meta(self, table: str, key: str) -> str:
with self._conn() as conn:
row = conn.execute(f'SELECT value FROM {table} WHERE key = ?', (key,)).fetchone()
return row[0] if row else None
def update_cli_meta(self, cli: str, path: str, sessions: int, messages: int, tokens: int = None):
table = f'meta_{cli}'
now = datetime.datetime.now().isoformat()
# 顺序: path, sessions, messages, total_tokens, last_sync
self._set_meta(table, 'path', path)
self._set_meta(table, 'sessions', str(sessions))
self._set_meta(table, 'messages', str(messages))
self._set_meta(table, 'total_tokens', str(tokens or 0))
self._set_meta(table, 'last_sync', now)
def update_total_meta(self, sessions: int, messages: int, tokens: int = None):
now = datetime.datetime.now().isoformat()
self._set_meta('meta', 'total_sessions', str(sessions))
self._set_meta('meta', 'total_messages', str(messages))
if tokens is not None:
self._set_meta('meta', 'total_tokens', str(tokens))
self._set_meta('meta', 'last_sync', now)
def get_total_meta(self) -> dict:
return {
'schema_version': int(self._get_meta('meta', 'schema_version') or 0),
'total_sessions': int(self._get_meta('meta', 'total_sessions') or 0),
'total_messages': int(self._get_meta('meta', 'total_messages') or 0),
'last_sync': self._get_meta('meta', 'last_sync'),
}
def get_file_mtime(self, file_path: str) -> int:
with self._conn() as conn:
row = conn.execute('SELECT file_mtime FROM sessions WHERE file_path = ?', (file_path,)).fetchone()
return row[0] if row else 0
def upsert_session(self, session_id: str, source: str, file_path: str,
cwd: str, messages: list, file_mtime: int, start_time: str = None):
if file_path and not file_path.startswith('claude:') and not os.path.isabs(file_path):
file_path = os.path.abspath(file_path)
total_tokens = sum(count_tokens(msg.get('content', '')) for msg in messages)
if not start_time and messages:
start_time = messages[0].get('time')
messages_json = json.dumps(messages, ensure_ascii=False)
with self._conn() as conn:
conn.execute('''
INSERT INTO sessions (file_path, session_id, source, cwd, messages, file_mtime, start_time, token_count)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(file_path) DO UPDATE SET
session_id=excluded.session_id, messages=excluded.messages,
file_mtime=excluded.file_mtime, start_time=excluded.start_time, token_count=excluded.token_count
''', (file_path, session_id, source, cwd, messages_json, file_mtime, start_time, total_tokens))
def get_cli_stats(self, cli: str) -> dict:
with self._conn() as conn:
sessions = conn.execute('SELECT COUNT(*) FROM sessions WHERE source = ?', (cli,)).fetchone()[0]
row = conn.execute('SELECT SUM(json_array_length(messages)) FROM sessions WHERE source = ?', (cli,)).fetchone()
messages = row[0] or 0
tokens = conn.execute('SELECT SUM(token_count) FROM sessions WHERE source = ?', (cli,)).fetchone()[0] or 0
return {'sessions': sessions, 'messages': messages, 'tokens': tokens}
def get_total_stats(self) -> dict:
with self._conn() as conn:
sessions = conn.execute('SELECT COUNT(*) FROM sessions').fetchone()[0]
row = conn.execute('SELECT SUM(json_array_length(messages)) FROM sessions').fetchone()
messages = row[0] or 0
tokens = conn.execute('SELECT SUM(token_count) FROM sessions').fetchone()[0] or 0
return {'sessions': sessions, 'messages': messages, 'tokens': tokens}
def get_token_stats(self) -> dict:
with self._conn() as conn:
rows = conn.execute('SELECT source, SUM(token_count) FROM sessions GROUP BY source').fetchall()
return {r[0]: r[1] or 0 for r in rows}
# === 清理孤立记录 ===
def prune(self) -> dict:
"""删除源文件已不存在的记录"""
removed = {'codex': 0, 'kiro': 0, 'gemini': 0, 'claude': 0}
with self._conn() as conn:
rows = conn.execute('SELECT file_path, source FROM sessions').fetchall()
for fp, source in rows:
if fp.startswith('claude:'):
continue # Claude 使用虚拟路径
if not os.path.exists(fp):
conn.execute('DELETE FROM sessions WHERE file_path = ?', (fp,))
removed[source] = removed.get(source, 0) + 1
return removed
# === 查询 ===
def search(self, keyword: str, source: str = None, limit: int = 50) -> list:
"""搜索消息内容"""
sql = "SELECT file_path, session_id, source, cwd, messages, start_time FROM sessions WHERE messages LIKE ?"
params = [f'%{keyword}%']
if source:
sql += " AND source = ?"
params.append(source)
sql += f" ORDER BY start_time DESC LIMIT {limit}"
results = []
with self._conn() as conn:
for row in conn.execute(sql, params):
results.append({
'file_path': row[0], 'session_id': row[1], 'source': row[2],
'cwd': row[3], 'messages': json.loads(row[4]), 'start_time': row[5]
})
return results
def get_session(self, file_path: str) -> dict:
"""获取单个会话"""
with self._conn() as conn:
row = conn.execute(
'SELECT file_path, session_id, source, cwd, messages, start_time, token_count FROM sessions WHERE file_path = ?',
(file_path,)
).fetchone()
if not row:
return None
return {
'file_path': row[0], 'session_id': row[1], 'source': row[2], 'cwd': row[3],
'messages': json.loads(row[4]), 'start_time': row[5], 'token_count': row[6]
}
def list_sessions(self, source: str = None, limit: int = 100, offset: int = 0) -> list:
"""列出会话"""
sql = "SELECT file_path, session_id, source, cwd, start_time, token_count FROM sessions"
params = []
if source:
sql += " WHERE source = ?"
params.append(source)
sql += f" ORDER BY start_time DESC LIMIT {limit} OFFSET {offset}"
results = []
with self._conn() as conn:
for row in conn.execute(sql, params):
results.append({
'file_path': row[0], 'session_id': row[1], 'source': row[2],
'cwd': row[3], 'start_time': row[4], 'token_count': row[5]
})
return results
# === 导出 ===
def export_json(self, output_path: str, source: str = None):
"""导出为 JSON"""
sql = "SELECT file_path, session_id, source, cwd, messages, start_time, token_count FROM sessions"
params = []
if source:
sql += " WHERE source = ?"
params.append(source)
sql += " ORDER BY start_time"
data = []
with self._conn() as conn:
for row in conn.execute(sql, params):
data.append({
'file_path': row[0], 'session_id': row[1], 'source': row[2], 'cwd': row[3],
'messages': json.loads(row[4]), 'start_time': row[5], 'token_count': row[6]
})
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
return len(data)
def export_csv(self, output_path: str, source: str = None):
"""导出为 CSV扁平化消息"""
import csv
sql = "SELECT session_id, source, cwd, messages, start_time FROM sessions"
params = []
if source:
sql += " WHERE source = ?"
params.append(source)
sql += " ORDER BY start_time"
count = 0
with open(output_path, 'w', encoding='utf-8', newline='') as f:
writer = csv.writer(f)
writer.writerow(['session_id', 'source', 'cwd', 'time', 'role', 'content'])
with self._conn() as conn:
for row in conn.execute(sql, params):
session_id, src, cwd, msgs_json, _ = row
for msg in json.loads(msgs_json):
writer.writerow([session_id, src, cwd, msg.get('time', ''), msg.get('role', ''), msg.get('content', '')])
count += 1
return count
# === 获取所有文件路径(用于 prune 检查) ===
def get_all_file_paths(self, source: str = None) -> set:
sql = "SELECT file_path FROM sessions"
params = []
if source:
sql += " WHERE source = ?"
params.append(source)
with self._conn() as conn:
return {row[0] for row in conn.execute(sql, params)}