#!/usr/bin/env python3
"""
MongoDB storage adapter.

Persists token usage records to a MongoDB database.
"""

import os
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import asdict

from .config_manager import UsageRecord

try:
    from pymongo import MongoClient
    from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
    MONGODB_AVAILABLE = True
except ImportError:
    # pymongo is optional; MongoDBStorage.__init__ refuses to run without it.
    MONGODB_AVAILABLE = False
    MongoClient = None


class MongoDBStorage:
    """MongoDB storage adapter for token usage records.

    Connection problems are reported via ``print`` and leave the instance
    disconnected (``is_connected()`` returns False) instead of raising; the
    read/write methods then return an empty or neutral value.

    Can be used as a context manager, which calls ``close()`` on exit.
    """

    def __init__(self, connection_string: Optional[str] = None, database_name: str = "tradingagents"):
        """Create the adapter and try to connect immediately.

        Args:
            connection_string: MongoDB URI. Falls back to the
                ``MONGODB_CONNECTION_STRING`` environment variable.
            database_name: Database that holds the ``token_usage`` collection.

        Raises:
            ImportError: If pymongo is not installed.
            ValueError: If no connection string is passed or configured.
        """
        if not MONGODB_AVAILABLE:
            raise ImportError("pymongo is not installed. Please install it with: pip install pymongo")

        # No hard-coded default URI: require an explicit argument or the env var.
        self.connection_string = connection_string or os.getenv("MONGODB_CONNECTION_STRING")
        if not self.connection_string:
            raise ValueError(
                "MongoDB连接字符串未配置。请通过以下方式之一进行配置:\n"
                "1. 设置环境变量 MONGODB_CONNECTION_STRING\n"
                "2. 在初始化时传入 connection_string 参数\n"
                "例如: MONGODB_CONNECTION_STRING=mongodb://localhost:27017/"
            )

        self.database_name = database_name
        self.collection_name = "token_usage"
        self.client = None
        self.db = None
        self.collection = None
        self._connected = False

        self._connect()

    def _connect(self):
        """Connect to MongoDB, create indexes, and set ``_connected``.

        On any failure the partially created client is closed so it does not
        linger, and the instance stays disconnected.
        """
        try:
            self.client = MongoClient(
                self.connection_string,
                serverSelectionTimeoutMS=5000  # 5-second server-selection timeout
            )
            # The ping forces a server round trip so an unreachable server is
            # detected here rather than on the first real operation.
            self.client.admin.command('ping')

            self.db = self.client[self.database_name]
            self.collection = self.db[self.collection_name]

            # Indexes speed up the time/provider/session queries below.
            self._create_indexes()

            self._connected = True
            print(f"✅ MongoDB连接成功: {self.database_name}.{self.collection_name}")

        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            print(f"❌ MongoDB连接失败: {e}")
            print("将使用本地JSON文件存储")
            self._connected = False
            self._release_client()
        except Exception as e:
            print(f"❌ MongoDB初始化失败: {e}")
            self._connected = False
            self._release_client()

    def _release_client(self):
        """Close the client if one exists and drop all connection handles."""
        if self.client is not None:
            self.client.close()
        self.client = None
        self.db = None
        self.collection = None

    def _create_indexes(self):
        """Create the collection indexes; failures are printed, not raised."""
        try:
            # Compound index: newest first, then provider and model.
            self.collection.create_index([
                ("timestamp", -1),
                ("provider", 1),
                ("model_name", 1)
            ])
            # Index on session ID.
            self.collection.create_index("session_id")
            # Index on analysis type.
            self.collection.create_index("analysis_type")
        except Exception as e:
            print(f"创建MongoDB索引失败: {e}")

    @staticmethod
    def _cutoff_iso(days: int) -> str:
        """Return ``now - days`` as an ISO-8601 string (naive local time).

        Queries compare ``timestamp`` as strings against this value, so this
        presumably requires stored timestamps to be ISO strings produced the
        same way — TODO confirm against UsageRecord.
        """
        return (datetime.now() - timedelta(days=days)).isoformat()

    def is_connected(self) -> bool:
        """Return True if the last connection attempt succeeded and close() was not called."""
        return self._connected

    def save_usage_record(self, record: UsageRecord) -> bool:
        """Insert one usage record; return True on success, False otherwise."""
        if not self._connected:
            return False

        try:
            # asdict() requires UsageRecord to be a dataclass.
            record_dict = asdict(record)
            # MongoDB-only bookkeeping field; stripped again when loading.
            record_dict['_created_at'] = datetime.now()

            result = self.collection.insert_one(record_dict)

            if result.inserted_id:
                return True
            print("MongoDB插入失败:未返回插入ID")
            return False

        except Exception as e:
            print(f"保存记录到MongoDB失败: {e}")
            return False

    def load_usage_records(self, limit: int = 10000, days: Optional[int] = None) -> List[UsageRecord]:
        """Load usage records, newest first.

        Args:
            limit: Maximum number of documents to fetch.
            days: If truthy, only return records from the last ``days`` days;
                ``None`` or ``0`` means no time filter.

        Returns:
            Parsed records; documents that fail to parse are printed and skipped.
        """
        if not self._connected:
            return []

        try:
            query = {}
            if days:
                query['timestamp'] = {'$gte': self._cutoff_iso(days)}

            cursor = self.collection.find(query).sort('timestamp', -1).limit(limit)

            records = []
            for doc in cursor:
                # Drop MongoDB-only fields that UsageRecord does not accept.
                doc.pop('_id', None)
                doc.pop('_created_at', None)

                try:
                    records.append(UsageRecord(**doc))
                except Exception as e:
                    print(f"解析记录失败: {e}, 记录: {doc}")
                    continue

            return records

        except Exception as e:
            print(f"从MongoDB加载记录失败: {e}")
            return []

    def get_usage_statistics(self, days: int = 30) -> Dict[str, Any]:
        """Aggregate total cost, tokens and request count over the last ``days`` days.

        Returns an all-zero summary when nothing matches, and ``{}`` when
        disconnected or on error.
        """
        if not self._connected:
            return {}

        try:
            pipeline = [
                {
                    '$match': {
                        'timestamp': {'$gte': self._cutoff_iso(days)}
                    }
                },
                {
                    '$group': {
                        '_id': None,
                        'total_cost': {'$sum': '$cost'},
                        'total_input_tokens': {'$sum': '$input_tokens'},
                        'total_output_tokens': {'$sum': '$output_tokens'},
                        'total_requests': {'$sum': 1}
                    }
                }
            ]

            result = list(self.collection.aggregate(pipeline))
            # With no matching documents $group emits nothing; treat as zeros.
            stats = result[0] if result else {}

            return {
                'period_days': days,
                'total_cost': round(stats.get('total_cost', 0), 4),
                'total_input_tokens': stats.get('total_input_tokens', 0),
                'total_output_tokens': stats.get('total_output_tokens', 0),
                'total_requests': stats.get('total_requests', 0)
            }

        except Exception as e:
            print(f"获取MongoDB统计失败: {e}")
            return {}

    def get_provider_statistics(self, days: int = 30) -> Dict[str, Dict[str, Any]]:
        """Aggregate cost, tokens and request count per provider over the last ``days`` days.

        Returns a mapping keyed by the ``provider`` field value; ``{}`` when
        disconnected or on error.
        """
        if not self._connected:
            return {}

        try:
            pipeline = [
                {
                    '$match': {
                        'timestamp': {'$gte': self._cutoff_iso(days)}
                    }
                },
                {
                    '$group': {
                        '_id': '$provider',
                        'cost': {'$sum': '$cost'},
                        'input_tokens': {'$sum': '$input_tokens'},
                        'output_tokens': {'$sum': '$output_tokens'},
                        'requests': {'$sum': 1}
                    }
                }
            ]

            return {
                row['_id']: {
                    'cost': round(row.get('cost', 0), 4),
                    'input_tokens': row.get('input_tokens', 0),
                    'output_tokens': row.get('output_tokens', 0),
                    'requests': row.get('requests', 0)
                }
                for row in self.collection.aggregate(pipeline)
            }

        except Exception as e:
            print(f"获取供应商统计失败: {e}")
            return {}

    def cleanup_old_records(self, days: int = 90) -> int:
        """Delete records older than ``days`` days and return how many were removed."""
        if not self._connected:
            return 0

        try:
            result = self.collection.delete_many({
                'timestamp': {'$lt': self._cutoff_iso(days)}
            })

            deleted_count = result.deleted_count
            if deleted_count > 0:
                print(f"清理了 {deleted_count} 条超过 {days} 天的记录")

            return deleted_count

        except Exception as e:
            print(f"清理旧记录失败: {e}")
            return 0

    def close(self):
        """Close the MongoDB connection; calling it again is a no-op."""
        if self.client is not None:
            self._release_client()
            print("MongoDB连接已关闭")
        self._connected = False

    def __enter__(self):
        """Return self so the adapter can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the connection on leaving the ``with`` block; never suppress exceptions."""
        self.close()
        return None