#!/usr/bin/env python3
"""
Configuration manager.

Manages API keys, model configuration, pricing settings and usage records.
Configuration is persisted as JSON files inside a config directory; usage
records are written to MongoDB when that backend is enabled and reachable,
and to a JSON file otherwise.
"""

import json
import os
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
from pathlib import Path

try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv is treated as optional: without it the .env file is
    # simply not loaded and only the real process environment is consulted.
    load_dotenv = None

try:
    from .mongodb_storage import MongoDBStorage
    MONGODB_AVAILABLE = True
except ImportError:
    MONGODB_AVAILABLE = False
    MongoDBStorage = None


@dataclass
class ModelConfig:
    """Configuration of a single LLM model."""
    provider: str  # Provider: dashscope, openai, google, etc.
    model_name: str  # Model name
    api_key: str  # API key
    base_url: Optional[str] = None  # Custom API endpoint
    max_tokens: int = 4000  # Maximum number of tokens
    temperature: float = 0.7  # Sampling temperature
    enabled: bool = True  # Whether the model is enabled


@dataclass
class PricingConfig:
    """Pricing of a single model."""
    provider: str  # Provider
    model_name: str  # Model name
    input_price_per_1k: float  # Input token price (per 1000 tokens)
    output_price_per_1k: float  # Output token price (per 1000 tokens)
    currency: str = "CNY"  # Currency unit


@dataclass
class UsageRecord:
    """A single recorded model invocation."""
    timestamp: str  # Timestamp (ISO format when written by this module)
    provider: str  # Provider
    model_name: str  # Model name
    input_tokens: int  # Number of input tokens
    output_tokens: int  # Number of output tokens
    cost: float  # Cost
    session_id: str  # Session ID
    analysis_type: str  # Analysis type


class ConfigManager:
    """Loads, saves and combines model, pricing, usage and settings data."""

    def __init__(self, config_dir: str = "config"):
        """Prepare the config directory and write default files if missing.

        Args:
            config_dir: Directory holding models.json, pricing.json,
                usage.json and settings.json. Created (including missing
                parents) when it does not exist.
        """
        self.config_dir = Path(config_dir)
        # parents=True so that a nested config path does not fail when an
        # intermediate directory is missing.
        self.config_dir.mkdir(parents=True, exist_ok=True)

        self.models_file = self.config_dir / "models.json"
        self.pricing_file = self.config_dir / "pricing.json"
        self.usage_file = self.config_dir / "usage.json"
        self.settings_file = self.config_dir / "settings.json"

        # Load the .env file (kept for backward compatibility)
        self._load_env_file()

        # Initialise MongoDB storage (if available)
        self.mongodb_storage = None
        self._init_mongodb_storage()

        self._init_default_configs()

    @staticmethod
    def _env_file_path() -> Path:
        """Return the .env path three directory levels above this file."""
        return Path(__file__).parent.parent.parent / ".env"

    def _load_env_file(self):
        """Load the .env file into the environment (backward compatibility).

        Values from the file override variables that are already set.
        """
        env_file = self._env_file_path()
        if not env_file.exists():
            return
        if load_dotenv is None:
            print("⚠️ 未安装python-dotenv,无法加载.env文件")
            return
        load_dotenv(env_file, override=True)

    def _get_env_api_key(self, provider: str) -> str:
        """Return the API key for ``provider`` from the environment.

        Returns an empty string for unknown providers or unset variables.
        The provider name is matched case-insensitively.
        """
        env_key_map = {
            "dashscope": "DASHSCOPE_API_KEY",
            "openai": "OPENAI_API_KEY",
            "google": "GOOGLE_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "deepseek": "DEEPSEEK_API_KEY"
        }

        env_key = env_key_map.get(provider.lower())
        if env_key:
            return os.getenv(env_key, "")
        return ""

    def _init_mongodb_storage(self):
        """Initialise MongoDB storage when available and enabled.

        Leaves ``self.mongodb_storage`` as ``None`` when the storage module
        could not be imported, when USE_MONGODB_STORAGE is not "true", or
        when connecting fails.
        """
        if not MONGODB_AVAILABLE:
            return

        # Check whether MongoDB storage is enabled
        use_mongodb = os.getenv("USE_MONGODB_STORAGE", "false").lower() == "true"
        if not use_mongodb:
            return

        try:
            connection_string = os.getenv("MONGODB_CONNECTION_STRING")
            database_name = os.getenv("MONGODB_DATABASE_NAME", "tradingagents")

            self.mongodb_storage = MongoDBStorage(
                connection_string=connection_string,
                database_name=database_name
            )

            if self.mongodb_storage.is_connected():
                print("✅ MongoDB存储已启用")
            else:
                self.mongodb_storage = None
                print("⚠️ MongoDB连接失败,将使用JSON文件存储")

        except Exception as e:
            print(f"❌ MongoDB初始化失败: {e}")
            self.mongodb_storage = None

    def _init_default_configs(self):
        """Write default models, pricing and settings files if missing."""
        # Default model configuration
        if not self.models_file.exists():
            default_models = [
                ModelConfig(
                    provider="dashscope",
                    model_name="qwen-turbo",
                    api_key="",
                    max_tokens=4000,
                    temperature=0.7
                ),
                ModelConfig(
                    provider="dashscope",
                    model_name="qwen-plus-latest",
                    api_key="",
                    max_tokens=8000,
                    temperature=0.7
                ),
                ModelConfig(
                    provider="openai",
                    model_name="gpt-3.5-turbo",
                    api_key="",
                    max_tokens=4000,
                    temperature=0.7,
                    enabled=False
                ),
                ModelConfig(
                    provider="openai",
                    model_name="gpt-4",
                    api_key="",
                    max_tokens=8000,
                    temperature=0.7,
                    enabled=False
                ),
                ModelConfig(
                    provider="google",
                    model_name="gemini-pro",
                    api_key="",
                    max_tokens=4000,
                    temperature=0.7,
                    enabled=False
                )
            ]
            self.save_models(default_models)

        # Default pricing configuration
        if not self.pricing_file.exists():
            default_pricing = [
                # Alibaba DashScope pricing (CNY)
                PricingConfig("dashscope", "qwen-turbo", 0.002, 0.006, "CNY"),
                PricingConfig("dashscope", "qwen-plus-latest", 0.004, 0.012, "CNY"),
                PricingConfig("dashscope", "qwen-max", 0.02, 0.06, "CNY"),

                # OpenAI pricing (USD)
                PricingConfig("openai", "gpt-3.5-turbo", 0.0015, 0.002, "USD"),
                PricingConfig("openai", "gpt-4", 0.03, 0.06, "USD"),
                PricingConfig("openai", "gpt-4-turbo", 0.01, 0.03, "USD"),

                # Google pricing (USD)
                PricingConfig("google", "gemini-pro", 0.00025, 0.0005, "USD"),
                PricingConfig("google", "gemini-pro-vision", 0.00025, 0.0005, "USD"),
            ]
            self.save_pricing(default_pricing)

        # Default settings
        if not self.settings_file.exists():
            # Default data directory lives under the user's Documents folder
            default_data_dir = os.path.join(
                os.path.expanduser("~"), "Documents", "TradingAgents", "data"
            )

            default_settings = {
                "default_provider": "dashscope",
                "default_model": "qwen-turbo",
                "enable_cost_tracking": True,
                "cost_alert_threshold": 100.0,  # Cost alert threshold
                "currency_preference": "CNY",
                "auto_save_usage": True,
                "max_usage_records": 10000,
                "data_dir": default_data_dir,  # Data directory
                "cache_dir": os.path.join(default_data_dir, "cache"),  # Cache directory
                "results_dir": os.path.join(
                    os.path.expanduser("~"), "Documents", "TradingAgents", "results"
                ),  # Results directory
                "auto_create_dirs": True  # Create directories automatically
            }
            self.save_settings(default_settings)

    def load_models(self) -> List[ModelConfig]:
        """Load model configs, preferring API keys from the environment.

        A model whose provider has an API key in the environment gets that
        key and is force-enabled. Returns an empty list (after printing the
        error) when the file cannot be read or parsed.
        """
        try:
            with open(self.models_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            models = [ModelConfig(**item) for item in data]

            # Merge API keys from the environment (higher priority)
            for model in models:
                env_api_key = self._get_env_api_key(model.provider)
                if env_api_key:
                    model.api_key = env_api_key
                    # A key in the environment automatically enables the model
                    model.enabled = True

            return models
        except Exception as e:
            print(f"加载模型配置失败: {e}")
            return []

    def save_models(self, models: List[ModelConfig]):
        """Write model configs to models.json; errors are printed."""
        try:
            data = [asdict(model) for model in models]
            with open(self.models_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"保存模型配置失败: {e}")

    def load_pricing(self) -> List[PricingConfig]:
        """Load pricing configs; returns [] (and prints) on failure."""
        try:
            with open(self.pricing_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return [PricingConfig(**item) for item in data]
        except Exception as e:
            print(f"加载定价配置失败: {e}")
            return []

    def save_pricing(self, pricing: List[PricingConfig]):
        """Write pricing configs to pricing.json; errors are printed."""
        try:
            data = [asdict(price) for price in pricing]
            with open(self.pricing_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"保存定价配置失败: {e}")

    def load_usage_records(self) -> List[UsageRecord]:
        """Load usage records from usage.json.

        Returns an empty list when the file does not exist or cannot be
        read/parsed (the error is printed in the latter case).
        """
        try:
            if not self.usage_file.exists():
                return []

            with open(self.usage_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return [UsageRecord(**item) for item in data]
        except Exception as e:
            print(f"加载使用记录失败: {e}")
            return []

    def save_usage_records(self, records: List[UsageRecord]):
        """Write usage records to usage.json; errors are printed."""
        try:
            data = [asdict(record) for record in records]
            with open(self.usage_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"保存使用记录失败: {e}")

    def add_usage_record(self, provider: str, model_name: str, input_tokens: int,
                         output_tokens: int, session_id: str,
                         analysis_type: str = "stock_analysis"):
        """Create, persist and return a usage record for one invocation.

        The record is saved to MongoDB when that storage is connected and
        the save succeeds; otherwise it is appended to the JSON file, which
        is trimmed to the newest ``max_usage_records`` entries.

        Returns:
            The created ``UsageRecord``.
        """
        # Compute the cost
        cost = self.calculate_cost(provider, model_name, input_tokens, output_tokens)

        record = UsageRecord(
            timestamp=datetime.now().isoformat(),
            provider=provider,
            model_name=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            session_id=session_id,
            analysis_type=analysis_type
        )

        # Prefer MongoDB storage
        if self.mongodb_storage and self.mongodb_storage.is_connected():
            success = self.mongodb_storage.save_usage_record(record)
            if success:
                return record
            else:
                print("⚠️ MongoDB保存失败,回退到JSON文件存储")

        # Fall back to JSON file storage
        records = self.load_usage_records()
        records.append(record)

        # Limit the number of stored records
        settings = self.load_settings()
        max_records = settings.get("max_usage_records", 10000)
        if len(records) > max_records:
            records = records[-max_records:]

        self.save_usage_records(records)
        return record

    def calculate_cost(self, provider: str, model_name: str,
                       input_tokens: int, output_tokens: int) -> float:
        """Compute the cost of a call from the pricing table.

        Returns:
            The cost rounded to 6 decimals in the currency of the matching
            pricing entry, or 0.0 when no entry matches provider and model.
        """
        pricing_configs = self.load_pricing()

        for pricing in pricing_configs:
            if pricing.provider == provider and pricing.model_name == model_name:
                input_cost = (input_tokens / 1000) * pricing.input_price_per_1k
                output_cost = (output_tokens / 1000) * pricing.output_price_per_1k
                return round(input_cost + output_cost, 6)

        return 0.0

    def _read_settings_file(self) -> Dict[str, Any]:
        """Return the settings exactly as stored in settings.json.

        No environment values are merged in. Returns {} (after printing the
        error) when the file cannot be read or parsed.
        """
        try:
            with open(self.settings_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"加载设置失败: {e}")
            return {}

    def load_settings(self) -> Dict[str, Any]:
        """Load settings and overlay configuration from the environment.

        The returned dict may contain secrets taken from environment
        variables; it is meant for reading, not for writing back to disk.
        """
        settings = self._read_settings_file()

        # Merge the other configuration values from the environment
        env_settings = {
            "finnhub_api_key": os.getenv("FINNHUB_API_KEY", ""),
            "reddit_client_id": os.getenv("REDDIT_CLIENT_ID", ""),
            "reddit_client_secret": os.getenv("REDDIT_CLIENT_SECRET", ""),
            "reddit_user_agent": os.getenv("REDDIT_USER_AGENT", ""),
            "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", ""),
            "log_level": os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO"),
            "data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", ""),  # Data directory env var
            "cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", ""),  # Cache directory env var
        }

        # Only override when the environment variable exists and is non-empty
        for key, value in env_settings.items():
            if value:
                settings[key] = value

        return settings

    def get_env_config_status(self) -> Dict[str, Any]:
        """Report which environment-based configuration values are present."""
        return {
            "env_file_exists": self._env_file_path().exists(),
            "api_keys": {
                "dashscope": bool(os.getenv("DASHSCOPE_API_KEY")),
                "openai": bool(os.getenv("OPENAI_API_KEY")),
                "google": bool(os.getenv("GOOGLE_API_KEY")),
                "anthropic": bool(os.getenv("ANTHROPIC_API_KEY")),
                "finnhub": bool(os.getenv("FINNHUB_API_KEY")),
            },
            "other_configs": {
                "reddit_configured": bool(
                    os.getenv("REDDIT_CLIENT_ID") and os.getenv("REDDIT_CLIENT_SECRET")
                ),
                "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
                "log_level": os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO"),
            }
        }

    def save_settings(self, settings: Dict[str, Any]):
        """Write settings to settings.json; errors are printed."""
        try:
            with open(self.settings_file, 'w', encoding='utf-8') as f:
                json.dump(settings, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"保存设置失败: {e}")

    def get_enabled_models(self) -> List[ModelConfig]:
        """Return models that are enabled and have a non-empty API key."""
        models = self.load_models()
        return [model for model in models if model.enabled and model.api_key]

    def get_model_by_name(self, provider: str, model_name: str) -> Optional[ModelConfig]:
        """Return the first model matching provider and name, else None."""
        models = self.load_models()
        for model in models:
            if model.provider == provider and model.model_name == model_name:
                return model
        return None

    def get_usage_statistics(self, days: int = 30) -> Dict[str, Any]:
        """Aggregate usage over the last ``days`` days.

        MongoDB statistics are used when that storage is connected and
        returns a non-empty result; otherwise the JSON usage file is
        aggregated. Records whose timestamp cannot be parsed or compared
        are skipped.
        """
        # Prefer statistics from MongoDB
        if self.mongodb_storage and self.mongodb_storage.is_connected():
            try:
                # Base statistics from MongoDB
                stats = self.mongodb_storage.get_usage_statistics(days)
                # Per-provider statistics
                provider_stats = self.mongodb_storage.get_provider_statistics(days)

                if stats:
                    stats["provider_stats"] = provider_stats
                    stats["records_count"] = stats.get("total_requests", 0)
                    return stats
            except Exception as e:
                print(f"⚠️ MongoDB统计获取失败,回退到JSON文件: {e}")

        # Fall back to statistics from the JSON file
        records = self.load_usage_records()

        # Keep only the records from the last N days
        cutoff_date = datetime.now() - timedelta(days=days)

        recent_records = []
        for record in records:
            try:
                record_date = datetime.fromisoformat(record.timestamp)
                if record_date >= cutoff_date:
                    recent_records.append(record)
            except (ValueError, TypeError):
                # ValueError: not an ISO timestamp. TypeError: non-string
                # timestamp, or a timezone-aware value that cannot be
                # compared with the naive cutoff.
                continue

        # Totals
        total_cost = sum(record.cost for record in recent_records)
        total_input_tokens = sum(record.input_tokens for record in recent_records)
        total_output_tokens = sum(record.output_tokens for record in recent_records)

        # Per-provider totals
        provider_stats = {}
        for record in recent_records:
            entry = provider_stats.setdefault(record.provider, {
                "cost": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "requests": 0
            })
            entry["cost"] += record.cost
            entry["input_tokens"] += record.input_tokens
            entry["output_tokens"] += record.output_tokens
            entry["requests"] += 1

        return {
            "period_days": days,
            "total_cost": round(total_cost, 4),
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_requests": len(recent_records),
            "provider_stats": provider_stats,
            "records_count": len(recent_records)
        }

    def get_data_dir(self) -> str:
        """Return the data directory path, falling back to the default."""
        settings = self.load_settings()
        data_dir = settings.get("data_dir")
        if not data_dir:
            # Nothing configured: use the default path
            data_dir = os.path.join(
                os.path.expanduser("~"), "Documents", "TradingAgents", "data"
            )
        return data_dir

    def set_data_dir(self, data_dir: str):
        """Persist a new data directory (and its cache sub-directory).

        Only the values stored in settings.json are rewritten; values that
        come from environment variables are deliberately not merged in, so
        secrets held only in the environment are not written to disk.
        """
        settings = self._read_settings_file()
        settings["data_dir"] = data_dir
        # Keep the cache directory in sync
        settings["cache_dir"] = os.path.join(data_dir, "cache")
        self.save_settings(settings)

        # Create the directories when automatic creation is enabled
        if settings.get("auto_create_dirs", True):
            self.ensure_directories_exist()

    def ensure_directories_exist(self):
        """Create the configured data, cache and results directories.

        The finnhub sub-directories are only created when a data directory
        is configured, so no relative ``finnhub_data`` tree is created in
        the current working directory. Failures are printed, not raised.
        """
        settings = self.load_settings()
        data_dir = settings.get("data_dir")

        directories = [
            data_dir,
            settings.get("cache_dir"),
            settings.get("results_dir"),
        ]
        if data_dir:
            finnhub_root = os.path.join(data_dir, "finnhub_data")
            directories.extend([
                finnhub_root,
                os.path.join(finnhub_root, "news_data"),
                os.path.join(finnhub_root, "insider_sentiment"),
                os.path.join(finnhub_root, "insider_transactions"),
            ])

        for directory in directories:
            if directory and not os.path.exists(directory):
                try:
                    os.makedirs(directory, exist_ok=True)
                    print(f"✅ 创建目录: {directory}")
                except Exception as e:
                    print(f"❌ 创建目录失败 {directory}: {e}")


class TokenTracker:
    """Records token usage and raises cost alerts."""

    def __init__(self, config_manager: ConfigManager):
        """Bind the tracker to the config manager used for persistence."""
        self.config_manager = config_manager

    def track_usage(self, provider: str, model_name: str, input_tokens: int,
                    output_tokens: int, session_id: Optional[str] = None,
                    analysis_type: str = "stock_analysis"):
        """Record one model invocation and check the cost alert.

        Args:
            session_id: Session identifier; a timestamp-based one is
                generated when omitted.

        Returns:
            The stored ``UsageRecord``, or ``None`` when cost tracking is
            disabled in the settings.
        """
        if session_id is None:
            session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Check whether cost tracking is enabled
        settings = self.config_manager.load_settings()
        if not settings.get("enable_cost_tracking", True):
            return None

        # Add the usage record
        record = self.config_manager.add_usage_record(
            provider=provider,
            model_name=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            session_id=session_id,
            analysis_type=analysis_type
        )

        # Check the cost alert
        self._check_cost_alert(record.cost)

        return record

    def _check_cost_alert(self, current_cost: float):
        """Print a warning when the last day's cost reaches the threshold.

        ``current_cost`` is accepted for interface compatibility; the check
        uses the aggregated one-day statistics.
        """
        settings = self.config_manager.load_settings()
        threshold = settings.get("cost_alert_threshold", 100.0)

        # Total cost over the last day
        today_stats = self.config_manager.get_usage_statistics(1)
        # Tolerate a statistics dict without a total (e.g. from another
        # storage backend) instead of failing after the record was saved.
        total_today = today_stats.get("total_cost", 0.0) or 0.0

        if total_today >= threshold:
            print(f"⚠️ 成本警告: 今日成本已达到 ¥{total_today:.4f},超过阈值 ¥{threshold}")

    def get_session_cost(self, session_id: str) -> float:
        """Return the summed cost of a session from the JSON usage file."""
        records = self.config_manager.load_usage_records()
        session_cost = sum(
            record.cost for record in records if record.session_id == session_id
        )
        return session_cost

    def estimate_cost(self, provider: str, model_name: str,
                      estimated_input_tokens: int,
                      estimated_output_tokens: int) -> float:
        """Estimate the cost of a call without recording it."""
        return self.config_manager.calculate_cost(
            provider, model_name, estimated_input_tokens, estimated_output_tokens
        )


# Global config manager and token tracker instances
config_manager = ConfigManager()
token_tracker = TokenTracker(config_manager)