#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""SQLite 存储模块 - 完整版"""
|
||
import sqlite3
|
||
import json
|
||
import os
|
||
import datetime
|
||
import tiktoken
|
||
|
||
SCHEMA_VERSION = 5
|
||
CLIS = ('codex', 'kiro', 'gemini', 'claude')
|
||
|
||
_encoder = tiktoken.get_encoding("cl100k_base")
|
||
|
||
def count_tokens(text: str) -> int:
|
||
return len(_encoder.encode(text)) if text else 0
|
||
|
||
class ChatStorage:
    """SQLite-backed store for chat sessions collected from several CLI tools.

    Layout:
      * ``sessions``   -- one row per session file; messages stored as a JSON array.
      * ``meta``       -- global key/value sync metadata.
      * ``meta_<cli>`` -- per-CLI key/value sync metadata, one table per CLIS entry.
    """

    def __init__(self, db_path: str):
        """Open (and initialize if needed) the database at *db_path*."""
        self.db_path = db_path
        # Ensure the parent directory exists; '.' covers bare filenames.
        os.makedirs(os.path.dirname(db_path) or '.', exist_ok=True)
        self._init_db()

    def _conn(self):
        # NOTE: callers use ``with self._conn() as conn``, which commits on
        # success / rolls back on error but does NOT close the connection
        # (sqlite3 context-manager semantics); connections are released on GC.
        return sqlite3.connect(self.db_path)

    def _init_db(self):
        """Create tables and indexes if missing, then stamp the schema version."""
        with self._conn() as conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value TEXT)''')
            # One key/value metadata table per supported CLI source.
            for cli in CLIS:
                conn.execute(f'''CREATE TABLE IF NOT EXISTS meta_{cli} (key TEXT PRIMARY KEY, value TEXT)''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS sessions (
                    file_path TEXT PRIMARY KEY,
                    session_id TEXT,
                    source TEXT NOT NULL,
                    cwd TEXT,
                    messages TEXT,
                    file_mtime INTEGER,
                    start_time TEXT,
                    token_count INTEGER DEFAULT 0
                )
            ''')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_source ON sessions(source)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_session_id ON sessions(session_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_start_time ON sessions(start_time)')
        # Outside the transaction above: _set_meta opens its own connection.
        self._set_meta('meta', 'schema_version', str(SCHEMA_VERSION))

    def _set_meta(self, table: str, key: str, value: str):
        """Upsert *key* -> *value* into the given meta table.

        *table* is interpolated into the SQL, but it is only ever a trusted,
        code-defined name ('meta' or 'meta_<cli>'), never user input.
        """
        with self._conn() as conn:
            conn.execute(f'INSERT OR REPLACE INTO {table} (key, value) VALUES (?, ?)', (key, value))

    def _get_meta(self, table: str, key: str) -> "str | None":
        """Return the stored value for *key*, or None when absent."""
        with self._conn() as conn:
            row = conn.execute(f'SELECT value FROM {table} WHERE key = ?', (key,)).fetchone()
            return row[0] if row else None

    def update_cli_meta(self, cli: str, path: str, sessions: int, messages: int, tokens: int = None):
        """Record per-CLI sync statistics in the meta_<cli> table."""
        table = f'meta_{cli}'
        now = datetime.datetime.now().isoformat()
        # Order: path, sessions, messages, total_tokens, last_sync
        self._set_meta(table, 'path', path)
        self._set_meta(table, 'sessions', str(sessions))
        self._set_meta(table, 'messages', str(messages))
        self._set_meta(table, 'total_tokens', str(tokens or 0))
        self._set_meta(table, 'last_sync', now)

    def update_total_meta(self, sessions: int, messages: int, tokens: int = None):
        """Record global sync statistics; token total only when provided."""
        now = datetime.datetime.now().isoformat()
        self._set_meta('meta', 'total_sessions', str(sessions))
        self._set_meta('meta', 'total_messages', str(messages))
        if tokens is not None:
            self._set_meta('meta', 'total_tokens', str(tokens))
        self._set_meta('meta', 'last_sync', now)

    def get_total_meta(self) -> dict:
        """Return global sync metadata; missing numeric keys default to 0."""
        return {
            'schema_version': int(self._get_meta('meta', 'schema_version') or 0),
            'total_sessions': int(self._get_meta('meta', 'total_sessions') or 0),
            'total_messages': int(self._get_meta('meta', 'total_messages') or 0),
            # Added for symmetry with update_total_meta(), which writes this
            # key but the original reader omitted it.
            'total_tokens': int(self._get_meta('meta', 'total_tokens') or 0),
            'last_sync': self._get_meta('meta', 'last_sync'),
        }

    def get_file_mtime(self, file_path: str) -> int:
        """Return the recorded mtime for *file_path*, or 0 if unknown."""
        with self._conn() as conn:
            row = conn.execute('SELECT file_mtime FROM sessions WHERE file_path = ?', (file_path,)).fetchone()
            return row[0] if row else 0

    def upsert_session(self, session_id: str, source: str, file_path: str,
                       cwd: str, messages: list, file_mtime: int, start_time: str = None):
        """Insert or refresh a session row keyed by *file_path*."""
        # Claude sessions use virtual 'claude:' paths; everything else is
        # normalized to an absolute path so the PRIMARY KEY stays stable.
        if file_path and not file_path.startswith('claude:') and not os.path.isabs(file_path):
            file_path = os.path.abspath(file_path)

        total_tokens = sum(count_tokens(msg.get('content', '')) for msg in messages)
        if not start_time and messages:
            # Fall back to the first message's timestamp.
            start_time = messages[0].get('time')

        messages_json = json.dumps(messages, ensure_ascii=False)
        with self._conn() as conn:
            # NOTE(review): cwd and source are not refreshed on conflict --
            # presumably immutable per file_path; confirm with callers.
            conn.execute('''
                INSERT INTO sessions (file_path, session_id, source, cwd, messages, file_mtime, start_time, token_count)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(file_path) DO UPDATE SET
                    session_id=excluded.session_id, messages=excluded.messages,
                    file_mtime=excluded.file_mtime, start_time=excluded.start_time, token_count=excluded.token_count
            ''', (file_path, session_id, source, cwd, messages_json, file_mtime, start_time, total_tokens))

    def get_cli_stats(self, cli: str) -> dict:
        """Return {'sessions', 'messages', 'tokens'} counts for one CLI source."""
        with self._conn() as conn:
            sessions = conn.execute('SELECT COUNT(*) FROM sessions WHERE source = ?', (cli,)).fetchone()[0]
            # json_array_length counts messages without deserializing in Python.
            row = conn.execute('SELECT SUM(json_array_length(messages)) FROM sessions WHERE source = ?', (cli,)).fetchone()
            messages = row[0] or 0
            tokens = conn.execute('SELECT SUM(token_count) FROM sessions WHERE source = ?', (cli,)).fetchone()[0] or 0
            return {'sessions': sessions, 'messages': messages, 'tokens': tokens}

    def get_total_stats(self) -> dict:
        """Return {'sessions', 'messages', 'tokens'} counts across all sources."""
        with self._conn() as conn:
            sessions = conn.execute('SELECT COUNT(*) FROM sessions').fetchone()[0]
            row = conn.execute('SELECT SUM(json_array_length(messages)) FROM sessions').fetchone()
            messages = row[0] or 0
            tokens = conn.execute('SELECT SUM(token_count) FROM sessions').fetchone()[0] or 0
            return {'sessions': sessions, 'messages': messages, 'tokens': tokens}

    def get_token_stats(self) -> dict:
        """Return {source: token_total} for every source present."""
        with self._conn() as conn:
            rows = conn.execute('SELECT source, SUM(token_count) FROM sessions GROUP BY source').fetchall()
            return {r[0]: r[1] or 0 for r in rows}

    # === Prune orphaned records ===
    def prune(self) -> dict:
        """Delete records whose source files no longer exist; return removal counts per source."""
        removed = {'codex': 0, 'kiro': 0, 'gemini': 0, 'claude': 0}
        with self._conn() as conn:
            # fetchall() first so we never delete while iterating a live cursor.
            rows = conn.execute('SELECT file_path, source FROM sessions').fetchall()
            for fp, source in rows:
                if fp.startswith('claude:'):
                    continue  # Claude uses virtual paths; nothing on disk to check.
                if not os.path.exists(fp):
                    conn.execute('DELETE FROM sessions WHERE file_path = ?', (fp,))
                    removed[source] = removed.get(source, 0) + 1
        return removed

    # === Queries ===
    def search(self, keyword: str, source: str = None, limit: int = 50) -> list:
        """Search message content via substring LIKE match.

        *keyword* is not escaped, so SQL LIKE wildcards (% and _) in it act
        as wildcards. Results are newest-first, capped at *limit*.
        """
        sql = "SELECT file_path, session_id, source, cwd, messages, start_time FROM sessions WHERE messages LIKE ?"
        params = [f'%{keyword}%']
        if source:
            sql += " AND source = ?"
            params.append(source)
        # LIMIT bound as a parameter (was f-string interpolated: injectable).
        sql += " ORDER BY start_time DESC LIMIT ?"
        params.append(limit)

        results = []
        with self._conn() as conn:
            for row in conn.execute(sql, params):
                results.append({
                    'file_path': row[0], 'session_id': row[1], 'source': row[2],
                    'cwd': row[3], 'messages': json.loads(row[4]), 'start_time': row[5]
                })
        return results

    def get_session(self, file_path: str) -> "dict | None":
        """Fetch a single session by file_path, or None if not found."""
        with self._conn() as conn:
            row = conn.execute(
                'SELECT file_path, session_id, source, cwd, messages, start_time, token_count FROM sessions WHERE file_path = ?',
                (file_path,)
            ).fetchone()
            if not row:
                return None
            return {
                'file_path': row[0], 'session_id': row[1], 'source': row[2], 'cwd': row[3],
                'messages': json.loads(row[4]), 'start_time': row[5], 'token_count': row[6]
            }

    def list_sessions(self, source: str = None, limit: int = 100, offset: int = 0) -> list:
        """List sessions newest-first (without message bodies), paged by limit/offset."""
        sql = "SELECT file_path, session_id, source, cwd, start_time, token_count FROM sessions"
        params = []
        if source:
            sql += " WHERE source = ?"
            params.append(source)
        # LIMIT/OFFSET bound as parameters (was f-string interpolated: injectable).
        sql += " ORDER BY start_time DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        results = []
        with self._conn() as conn:
            for row in conn.execute(sql, params):
                results.append({
                    'file_path': row[0], 'session_id': row[1], 'source': row[2],
                    'cwd': row[3], 'start_time': row[4], 'token_count': row[5]
                })
        return results

    # === Export ===
    def export_json(self, output_path: str, source: str = None):
        """Export sessions (optionally one source) to a JSON file; return row count."""
        sql = "SELECT file_path, session_id, source, cwd, messages, start_time, token_count FROM sessions"
        params = []
        if source:
            sql += " WHERE source = ?"
            params.append(source)
        sql += " ORDER BY start_time"

        data = []
        with self._conn() as conn:
            for row in conn.execute(sql, params):
                data.append({
                    'file_path': row[0], 'session_id': row[1], 'source': row[2], 'cwd': row[3],
                    'messages': json.loads(row[4]), 'start_time': row[5], 'token_count': row[6]
                })

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        return len(data)

    def export_csv(self, output_path: str, source: str = None):
        """Export to CSV with one row per message (flattened); return message count."""
        import csv
        sql = "SELECT session_id, source, cwd, messages, start_time FROM sessions"
        params = []
        if source:
            sql += " WHERE source = ?"
            params.append(source)
        sql += " ORDER BY start_time"

        count = 0
        with open(output_path, 'w', encoding='utf-8', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['session_id', 'source', 'cwd', 'time', 'role', 'content'])
            with self._conn() as conn:
                for row in conn.execute(sql, params):
                    session_id, src, cwd, msgs_json, _ = row
                    for msg in json.loads(msgs_json):
                        writer.writerow([session_id, src, cwd, msg.get('time', ''), msg.get('role', ''), msg.get('content', '')])
                        count += 1
        return count

    # === All known file paths (for prune-style existence checks) ===
    def get_all_file_paths(self, source: str = None) -> set:
        """Return the set of file_path keys, optionally filtered by source."""
        sql = "SELECT file_path FROM sessions"
        params = []
        if source:
            sql += " WHERE source = ?"
            params.append(source)
        with self._conn() as conn:
            return {row[0] for row in conn.execute(sql, params)}