"""DashScope (Alibaba Bailian) large-model adapter.

Provides a LangChain-compatible chat-model interface to Alibaba's DashScope
models for TradingAgents.
"""

import os
import json
from typing import Any, Dict, List, Optional, Union, Iterator, AsyncIterator, Sequence

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field, SecretStr

import dashscope
from dashscope import Generation

from ..config.config_manager import token_tracker


class ChatDashScope(BaseChatModel):
    """LangChain adapter for Alibaba DashScope chat models."""

    # Model configuration (Field descriptions kept as user-facing text).
    model: str = Field(default="qwen-turbo", description="DashScope 模型名称")
    api_key: Optional[SecretStr] = Field(default=None, description="DashScope API 密钥")
    temperature: float = Field(default=0.1, description="生成温度")
    max_tokens: int = Field(default=2000, description="最大生成token数")
    top_p: float = Field(default=0.9, description="核采样参数")

    # Internal attributes. Leading-underscore annotations are treated by
    # pydantic as private attributes, which makes plain assignment after
    # construction legal (required by bind_tools below).
    _client: Any = None
    _tools: Any = None  # tool specs recorded by bind_tools(); no native tool-calling

    def __init__(self, **kwargs):
        """Initialise the adapter and configure the DashScope API key.

        The key is taken from the ``api_key`` field if provided, otherwise
        from the ``DASHSCOPE_API_KEY`` environment variable.

        Raises:
            ValueError: if no API key can be resolved.
        """
        super().__init__(**kwargs)

        api_key = self.api_key
        if api_key is None:
            api_key = os.getenv("DASHSCOPE_API_KEY")
        if api_key is None:
            raise ValueError(
                "DashScope API key not found. Please set DASHSCOPE_API_KEY environment variable "
                "or pass api_key parameter."
            )

        # The dashscope SDK uses a module-level key rather than a client object.
        if isinstance(api_key, SecretStr):
            dashscope.api_key = api_key.get_secret_value()
        else:
            dashscope.api_key = api_key

    @property
    def _llm_type(self) -> str:
        """Return the LLM type tag used by LangChain."""
        return "dashscope"

    def _convert_messages_to_dashscope_format(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, str]]:
        """Convert LangChain messages into DashScope's role/content dicts.

        Unknown message classes default to the ``user`` role. Multimodal
        (list-typed) content is flattened to its text parts only.
        """
        dashscope_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                role = "system"
            elif isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                # Fall back to treating anything else as a user message.
                role = "user"

            content = message.content
            if isinstance(content, list):
                # Multimodal content: keep only the text fragments.
                text_content = ""
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        text_content += item.get("text", "")
                content = text_content

            dashscope_messages.append({
                "role": role,
                "content": str(content)
            })
        return dashscope_messages

    def _track_token_usage(
        self,
        messages: List[BaseMessage],
        input_tokens: int,
        output_tokens: int,
        session_id: Optional[str],
        analysis_type: str,
    ) -> None:
        """Record token usage with the shared TokenTracker (best effort).

        Tracking failures are logged and swallowed so they never break the
        main generation path.
        """
        if input_tokens <= 0 and output_tokens <= 0:
            return
        try:
            if session_id is None:
                # Derive a stable-ish session id when the caller gave none.
                session_id = f"dashscope_{hash(str(messages)) % 10000}"
            token_tracker.track_usage(
                provider="dashscope",
                model_name=self.model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                session_id=session_id,
                analysis_type=analysis_type,
            )
        except Exception as track_error:
            print(f"Token tracking failed: {track_error}")

    @staticmethod
    def _extract_usage(response: Any) -> tuple:
        """Pull (input_tokens, output_tokens) out of a DashScope response.

        Falls back to a 30/70 input/output split when only ``total_tokens``
        is reported.
        """
        input_tokens = 0
        output_tokens = 0
        if hasattr(response, 'usage') and response.usage:
            usage = response.usage
            if hasattr(usage, 'input_tokens'):
                input_tokens = usage.input_tokens
            if hasattr(usage, 'output_tokens'):
                output_tokens = usage.output_tokens
            if input_tokens == 0 and output_tokens == 0 and hasattr(usage, 'total_tokens'):
                # Only a combined count available: estimate the split
                # (assume 30% input, 70% output).
                total_tokens = usage.total_tokens
                input_tokens = int(total_tokens * 0.3)
                output_tokens = int(total_tokens * 0.7)
        return input_tokens, output_tokens

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat completion via the DashScope Generation API.

        Extra kwargs are forwarded to ``Generation.call`` except for the
        tracking-only keys ``session_id`` and ``analysis_type``.

        Raises:
            Exception: on a non-200 API status or a transport failure.
        """
        # BUGFIX: pop tracking-only keys so they are NOT forwarded to the
        # DashScope API (previously they leaked into Generation.call).
        session_id = kwargs.pop('session_id', None)
        analysis_type = kwargs.pop('analysis_type', 'stock_analysis')

        dashscope_messages = self._convert_messages_to_dashscope_format(messages)

        request_params = {
            "model": self.model,
            "messages": dashscope_messages,
            "result_format": "message",
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
        }
        if stop:
            request_params["stop"] = stop
        request_params.update(kwargs)

        try:
            response = Generation.call(**request_params)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise Exception(f"Error calling DashScope API: {str(e)}") from e

        if response.status_code != 200:
            # Raised outside the try above so it is not double-wrapped.
            raise Exception(f"DashScope API error: {response.code} - {response.message}")

        message_content = response.output.choices[0].message.content

        input_tokens, output_tokens = self._extract_usage(response)
        self._track_token_usage(messages, input_tokens, output_tokens, session_id, analysis_type)

        ai_message = AIMessage(content=message_content)
        generation = ChatGeneration(message=ai_message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async wrapper over the synchronous ``_generate``.

        NOTE: this currently blocks the event loop; a true async
        implementation could use ``asyncio.to_thread`` or the SDK's aio API.
        """
        return self._generate(messages, stop, run_manager, **kwargs)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], type, BaseTool]],
        **kwargs: Any,
    ) -> "ChatDashScope":
        """Bind tool definitions to a new model instance.

        DashScope has no native tool-calling here, so the formatted tool
        specs are only stored on the instance for application-level use.
        """
        formatted_tools = []
        for tool in tools:
            if hasattr(tool, "name") and hasattr(tool, "description"):
                # A BaseTool instance.
                formatted_tools.append({
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": getattr(tool, "args_schema", {})
                })
            elif isinstance(tool, dict):
                formatted_tools.append(tool)
            else:
                # Best effort: convert anything else to OpenAI tool format.
                try:
                    formatted_tools.append(convert_to_openai_tool(tool))
                except Exception:
                    pass

        new_instance = self.__class__(
            model=self.model,
            api_key=self.api_key,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_p=self.top_p,
            **kwargs
        )
        # Legal because _tools is declared as a private attribute above;
        # previously this assignment raised on pydantic v2 models.
        new_instance._tools = formatted_tools
        return new_instance

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model configuration."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
        }


# Catalog of supported DashScope models (user-facing text kept in Chinese).
DASHSCOPE_MODELS = {
    # 通义千问系列
    "qwen-turbo": {
        "description": "通义千问 Turbo - 快速响应,适合日常对话",
        "context_length": 8192,
        "recommended_for": ["快速任务", "日常对话", "简单分析"]
    },
    "qwen-plus": {
        "description": "通义千问 Plus - 平衡性能和成本",
        "context_length": 32768,
        "recommended_for": ["复杂分析", "专业任务", "深度思考"]
    },
    "qwen-max": {
        "description": "通义千问 Max - 最强性能",
        "context_length": 32768,
        "recommended_for": ["最复杂任务", "专业分析", "高质量输出"]
    },
    "qwen-max-longcontext": {
        "description": "通义千问 Max 长文本版 - 支持超长上下文",
        "context_length": 1000000,
        "recommended_for": ["长文档分析", "大量数据处理", "复杂推理"]
    },
}


def get_available_models() -> Dict[str, Dict[str, Any]]:
    """Return the catalog of available DashScope models."""
    return DASHSCOPE_MODELS


def create_dashscope_llm(
    model: str = "qwen-plus",
    api_key: Optional[str] = None,
    temperature: float = 0.1,
    max_tokens: int = 2000,
    **kwargs
) -> ChatDashScope:
    """Convenience factory for a ChatDashScope instance."""
    return ChatDashScope(
        model=model,
        api_key=api_key,
        temperature=temperature,
        max_tokens=max_tokens,
        **kwargs
    )