TradingAgents/tradingagents/llm_adapters/dashscope_adapter.py

289 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
阿里百炼大模型 (DashScope) 适配器
为 TradingAgents 提供阿里百炼大模型的 LangChain 兼容接口
"""
import os
import json
from typing import Any, Dict, List, Optional, Union, Iterator, AsyncIterator, Sequence
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field, SecretStr
import dashscope
from dashscope import Generation
from ..config.config_manager import token_tracker
class ChatDashScope(BaseChatModel):
    """LangChain chat-model adapter for Alibaba Cloud Bailian (DashScope) models.

    Each request goes through ``dashscope.Generation.call`` with
    ``result_format="message"``. When the response reports token usage it is
    recorded, best effort, through ``token_tracker``.

    Tool calling is not supported by this adapter. ``bind_tools`` only stores
    the converted tool specs on the returned instance (``_tools``); they are
    never sent to DashScope.
    """

    # Model configuration
    model: str = Field(default="qwen-turbo", description="DashScope 模型名称")
    api_key: Optional[SecretStr] = Field(default=None, description="DashScope API 密钥")
    temperature: float = Field(default=0.1, description="生成温度")
    max_tokens: int = Field(default=2000, description="最大生成token数")
    top_p: float = Field(default=0.9, description="核采样参数")

    # Internal attributes (not referenced by any method of this class)
    _client: Any = None

    def __init__(self, **kwargs):
        """Validate that an API key is available and publish it to the SDK.

        The key comes from the ``api_key`` field, or else from the
        ``DASHSCOPE_API_KEY`` environment variable.

        Raises:
            ValueError: if neither source provides a non-empty key.
        """
        super().__init__(**kwargs)
        api_key = self._resolve_api_key()
        # An empty key can never authenticate. Fail here with a clear message
        # instead of on the first request.
        if not api_key:
            raise ValueError(
                "DashScope API key not found. Please set DASHSCOPE_API_KEY environment variable "
                "or pass api_key parameter."
            )
        # Kept for backward compatibility with code that relies on the SDK-wide
        # key. Requests made by this class also pass the key explicitly (see
        # _generate), so instances configured with different keys do not
        # interfere with each other.
        dashscope.api_key = api_key

    def _resolve_api_key(self) -> Optional[str]:
        """Return the plain-text key: the ``api_key`` field, else ``DASHSCOPE_API_KEY``."""
        key = self.api_key
        if key is None:
            return os.getenv("DASHSCOPE_API_KEY")
        # The field is declared as SecretStr. Still accept a raw str in case
        # validation was bypassed.
        return key.get_secret_value() if isinstance(key, SecretStr) else key

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this model type."""
        return "dashscope"

    def _convert_messages_to_dashscope_format(self, messages: List[BaseMessage]) -> List[Dict[str, str]]:
        """Convert LangChain messages into DashScope ``{"role", "content"}`` dicts.

        System, Human and AI messages map to ``system``, ``user`` and
        ``assistant``. Any other message type is sent as ``user``. For list
        (multimodal) content, only the text parts are kept and concatenated.
        """
        dashscope_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                role = "system"
            elif isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                # Fallback: treat unknown message types as user input.
                role = "user"

            content = message.content
            if isinstance(content, list):
                # List content may mix bare strings and {"type": "text", ...}
                # blocks. Non-text blocks (e.g. images) are dropped.
                text_parts = []
                for item in content:
                    if isinstance(item, str):
                        text_parts.append(item)
                    elif isinstance(item, dict) and item.get("type") == "text":
                        text_parts.append(item.get("text") or "")
                content = "".join(text_parts)

            dashscope_messages.append({
                "role": role,
                "content": str(content),
            })
        return dashscope_messages

    @staticmethod
    def _extract_token_usage(response: Any) -> tuple:
        """Return ``(input_tokens, output_tokens)`` as reported by a DashScope response.

        Missing or ``None`` counters count as 0. If no output count is reported
        but a total is, the output count is derived as ``total - input`` when
        the input count is known. When neither count is known, the total is
        split with a rough 30% input / 70% output estimate.
        """
        usage = getattr(response, "usage", None)
        if not usage:
            return 0, 0

        input_tokens = getattr(usage, "input_tokens", None)
        output_tokens = getattr(usage, "output_tokens", None)
        if output_tokens is None:
            total_tokens = getattr(usage, "total_tokens", None)
            if total_tokens:
                if input_tokens is not None:
                    output_tokens = max(int(total_tokens) - int(input_tokens), 0)
                else:
                    # No split reported at all: fall back to an estimate.
                    input_tokens = int(total_tokens * 0.3)
                    output_tokens = int(total_tokens * 0.7)
        return int(input_tokens or 0), int(output_tokens or 0)

    def _track_token_usage(
        self,
        messages: List[BaseMessage],
        input_tokens: int,
        output_tokens: int,
        session_id: Optional[str],
        analysis_type: str,
    ) -> None:
        """Report usage to ``token_tracker``. Never raises; failures are printed."""
        try:
            if session_id is None:
                # Fallback id derived from the conversation. str hashes are
                # salted per process, so this id is not stable across runs.
                session_id = f"dashscope_{hash(str(messages))%10000}"
            token_tracker.track_usage(
                provider="dashscope",
                model_name=self.model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                session_id=session_id,
                analysis_type=analysis_type,
            )
        except Exception as track_error:
            # Tracking is auxiliary and must not break generation.
            print(f"Token tracking failed: {track_error}")

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send ``messages`` to DashScope and wrap the reply in a ``ChatResult``.

        Args:
            messages: conversation to send.
            stop: optional stop sequences, forwarded as ``stop``.
            run_manager: LangChain callback manager (unused).
            **kwargs: extra DashScope request parameters. They override the
                defaults built from the model fields. ``session_id`` and
                ``analysis_type`` are consumed for token tracking only and
                are not sent to the API.

        Raises:
            Exception: on a non-200 status, or on any failure while calling the
                API or parsing its response. The message starts with
                "Error calling DashScope API:" and the original error is chained.
        """
        # Tracking-only options: consumed here, never forwarded to DashScope.
        session_id = kwargs.pop("session_id", None)
        analysis_type = kwargs.pop("analysis_type", "stock_analysis")

        request_params: Dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages_to_dashscope_format(messages),
            "result_format": "message",
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            # Per-request key, so this instance does not depend on whichever
            # key was last written to the shared dashscope.api_key.
            "api_key": self._resolve_api_key(),
        }
        if stop:
            request_params["stop"] = stop
        # Caller-supplied options take precedence over the defaults above.
        request_params.update(kwargs)

        try:
            response = Generation.call(**request_params)
            if response.status_code != 200:
                raise Exception(f"DashScope API error: {response.code} - {response.message}")

            message_content = response.output.choices[0].message.content

            input_tokens, output_tokens = self._extract_token_usage(response)
            if input_tokens > 0 or output_tokens > 0:
                self._track_token_usage(
                    messages, input_tokens, output_tokens, session_id, analysis_type
                )

            generation = ChatGeneration(message=AIMessage(content=message_content))
            return ChatResult(generations=[generation])
        except Exception as e:
            raise Exception(f"Error calling DashScope API: {str(e)}") from e

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run the blocking ``_generate`` in the default executor.

        ``Generation.call`` is synchronous. Calling it directly here would
        block the event loop for the whole request.
        """
        import asyncio
        import functools

        sync_run_manager = run_manager.get_sync() if run_manager else None
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            functools.partial(self._generate, messages, stop, sync_run_manager, **kwargs),
        )

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], type, BaseTool]],
        **kwargs: Any,
    ) -> "ChatDashScope":
        """Return a new instance carrying the given tool specs in ``_tools``.

        DashScope tool calling is not wired up here. The stored specs are never
        sent with a request, so tool handling has to happen at the application
        layer. The specs are stored as follows:

        * Objects with ``name`` and ``description`` are stored as
          ``{"name", "description", "parameters"}``, where ``parameters`` comes
          from ``args_schema``.
        * Dicts are stored unchanged.
        * Anything else goes through ``convert_to_openai_tool``; tools that fail
          conversion are skipped with a printed warning.

        Extra ``kwargs`` are passed to the new instance's constructor.
        """
        formatted_tools = []
        for tool in tools:
            if hasattr(tool, "name") and hasattr(tool, "description"):
                formatted_tools.append({
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": getattr(tool, "args_schema", {}),
                })
            elif isinstance(tool, dict):
                formatted_tools.append(tool)
            else:
                try:
                    formatted_tools.append(convert_to_openai_tool(tool))
                except Exception as convert_error:
                    # Best effort: skip unconvertible tools, but do not drop
                    # them silently.
                    print(f"Skipping tool {tool!r}: cannot convert to OpenAI tool format ({convert_error})")

        new_instance = self.__class__(
            model=self.model,
            api_key=self.api_key,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_p=self.top_p,
            **kwargs,
        )
        new_instance._tools = formatted_tools
        return new_instance

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters that identify this model configuration to LangChain."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
        }
# Supported DashScope models: maps a model name to a metadata dict containing
# a "description" string, a "context_length" int (presumably in tokens —
# confirm) and a "recommended_for" list of suggested use cases. The
# description and recommendation strings are written in Chinese.
# NOTE(review): context_length values may be out of date relative to the
# current DashScope model list (the qwen-max-longcontext figure in particular
# looks high) — verify against the official docs.
DASHSCOPE_MODELS = {
    # Tongyi Qianwen (Qwen) series
    "qwen-turbo": {
        "description": "通义千问 Turbo - 快速响应,适合日常对话",
        "context_length": 8192,
        "recommended_for": ["快速任务", "日常对话", "简单分析"]
    },
    "qwen-plus": {
        "description": "通义千问 Plus - 平衡性能和成本",
        "context_length": 32768,
        "recommended_for": ["复杂分析", "专业任务", "深度思考"]
    },
    "qwen-max": {
        "description": "通义千问 Max - 最强性能",
        "context_length": 32768,
        "recommended_for": ["最复杂任务", "专业分析", "高质量输出"]
    },
    "qwen-max-longcontext": {
        "description": "通义千问 Max 长文本版 - 支持超长上下文",
        "context_length": 1000000,
        "recommended_for": ["长文档分析", "大量数据处理", "复杂推理"]
    },
}
def get_available_models() -> Dict[str, Dict[str, Any]]:
    """Return the catalogue of known DashScope models.

    The result is a deep copy of ``DASHSCOPE_MODELS``. Callers may modify it
    freely without altering the shared module-level table used by every
    other caller in the process.
    """
    import copy

    return copy.deepcopy(DASHSCOPE_MODELS)
def create_dashscope_llm(
    model: str = "qwen-plus",
    api_key: Optional[str] = None,
    temperature: float = 0.1,
    max_tokens: int = 2000,
    **kwargs
) -> ChatDashScope:
    """Build a ``ChatDashScope`` using this module's convenience defaults.

    Args:
        model: DashScope model name. Defaults to ``"qwen-plus"``, which differs
            from the class default.
        api_key: explicit API key. When ``None``, ``ChatDashScope`` falls back
            to the ``DASHSCOPE_API_KEY`` environment variable.
        temperature: sampling temperature.
        max_tokens: maximum number of tokens to generate.
        **kwargs: any other ``ChatDashScope`` field (e.g. ``top_p``).

    Returns:
        The configured ``ChatDashScope`` instance.
    """
    settings = {
        "model": model,
        "api_key": api_key,
        "temperature": temperature,
        "max_tokens": max_tokens,
        **kwargs,
    }
    return ChatDashScope(**settings)