# TradingAgents/tradingagents/llm_adapters/ernie_adapter.py
#
# 156 lines
# 5.6 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 文心一言模型适配器
# 支持百度文心一言大模型
import json
import os
from typing import Any, Dict, Iterator, List, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
    Generation,
    LLMResult,
)
from pydantic import Field
class ErnieAdapter(BaseChatModel):
    """LangChain chat-model adapter for Baidu's ERNIE Bot (Wenxin Yiyan) service.

    Every ``_generate`` call first exchanges the API key / secret key for an
    OAuth access token and then posts the conversation to
    ``{base_url}/{model_name}``.

    Credentials come from the constructor arguments, falling back to the
    ``BAIDU_API_KEY`` and ``BAIDU_SECRET_KEY`` environment variables.
    """

    model_name: str = Field(default="ernie-4.0-8k", description="模型名称")
    api_key: Optional[str] = Field(default=None, description="API密钥")
    secret_key: Optional[str] = Field(default=None, description="Secret密钥")
    base_url: str = Field(
        default="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
        description="API基础URL",
    )
    temperature: float = Field(default=0.7, description="温度参数")
    max_tokens: int = Field(default=2000, description="最大token数")

    def __init__(
        self,
        model_name: str = "ernie-4.0-8k",
        api_key: Optional[str] = None,
        secret_key: Optional[str] = None,
        base_url: str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ):
        """Create the adapter and resolve its credentials.

        Args:
            model_name: Model identifier appended to ``base_url`` as a path
                segment when sending a chat request.
            api_key: Baidu API key; defaults to ``$BAIDU_API_KEY``.
            secret_key: Baidu secret key; defaults to ``$BAIDU_SECRET_KEY``.
            base_url: Chat endpoint prefix.
            temperature: Sampling temperature sent with every request.
            max_tokens: Sent to the API as ``max_output_tokens``.
            **kwargs: Forwarded to ``BaseChatModel.__init__``.

        Raises:
            ValueError: If either credential is still missing after the
                environment fallback.
        """
        super().__init__(
            model_name=model_name,
            api_key=api_key or os.getenv("BAIDU_API_KEY"),
            secret_key=secret_key or os.getenv("BAIDU_SECRET_KEY"),
            base_url=base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )
        if not self.api_key or not self.secret_key:
            raise ValueError("请设置BAIDU_API_KEY和BAIDU_SECRET_KEY环境变量")

    @property
    def _llm_type(self) -> str:
        """Return the short identifier LangChain uses for this model type."""
        return "ernie"

    def _get_access_token(self) -> str:
        """Exchange the API key and secret key for an OAuth access token.

        Returns:
            The ``access_token`` value from the token endpoint's JSON reply.

        Raises:
            requests.RequestException: On network failure, timeout, or a
                non-2xx HTTP status.
            RuntimeError: If the reply carries no ``access_token``.
        """
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.secret_key,
        }
        # NOTE(review): the credentials travel in the query string, so the URL
        # that requests embeds in HTTP error messages presumably contains
        # them -- avoid logging those errors verbatim.
        # Bound the wait so an unresponsive token endpoint cannot hang the
        # caller; 30 seconds matches the chat request in _generate.
        response = requests.post(url, params=params, timeout=30)
        response.raise_for_status()
        result = response.json()

        token = result.get("access_token")
        if not token:
            # Surface the server's explanation instead of a bare KeyError.
            detail = result.get("error_description") or result.get("error") or result
            raise RuntimeError(f"获取访问令牌失败: {detail}")
        return token

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send the conversation to the chat endpoint and wrap the reply.

        Args:
            messages: Conversation so far, in LangChain message form.
            stop: Optional stop sequences, sent as ``stop`` when non-empty.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A ``ChatResult`` holding a single generation whose message content
            is the ``result`` field of the API reply.

        Raises:
            RuntimeError: For any failure (token exchange, transport, HTTP
                status, API-level error, malformed reply). The original
                exception is attached as ``__cause__``.
        """
        try:
            access_token = self._get_access_token()

            data: Dict[str, Any] = {
                "messages": self._format_messages(messages),
                "temperature": self.temperature,
                "max_output_tokens": self.max_tokens,
            }
            if stop:
                data["stop"] = stop

            response = requests.post(
                f"{self.base_url}/{self.model_name}",
                # Let requests build (and URL-encode) the query string.
                params={"access_token": access_token},
                headers={"Content-Type": "application/json"},
                json=data,
                timeout=30,
            )
            response.raise_for_status()
            result = response.json()

            # The service can report a failure inside the JSON body
            # (error_code / error_msg) rather than through the HTTP status;
            # without this check it would surface as KeyError('result').
            if "error_code" in result:
                raise RuntimeError(
                    f"{result.get('error_msg', '')} "
                    f"(error_code={result['error_code']})"
                )

            message = AIMessage(content=result["result"])
            return ChatResult(generations=[ChatGeneration(message=message)])
        except Exception as e:
            raise RuntimeError(f"文心一言API调用失败: {e}") from e

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten LangChain message content to the plain string the API takes."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Content-block form: keep bare strings and the "text" of dict
            # blocks; other block kinds have no textual representation here.
            pieces = []
            for block in content:
                if isinstance(block, str):
                    pieces.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    pieces.append(block["text"])
            return "".join(pieces)
        return str(content)

    def _format_messages(self, messages: List[BaseMessage]) -> List[Dict[str, str]]:
        """Convert LangChain messages into ERNIE ``{"role", "content"}`` dicts.

        System messages become ``user`` turns prefixed with ``系统提示: ``,
        human messages become ``user`` turns and AI messages become
        ``assistant`` turns. Any other message type is skipped.

        The chat API expects user and assistant turns to alternate, so
        adjacent turns that map to the same role (for example a system prompt
        followed by a human message) are merged into one turn, separated by a
        blank line.

        Args:
            messages: Messages to convert, in conversation order.

        Returns:
            The converted turns, in order, with no two neighbours sharing a
            role.
        """
        formatted: List[Dict[str, str]] = []
        for message in messages:
            text = self._content_to_text(message.content)
            if isinstance(message, SystemMessage):
                role, text = "user", f"系统提示: {text}"
            elif isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                # Unsupported message kinds are dropped, as before.
                continue

            if formatted and formatted[-1]["role"] == role:
                formatted[-1]["content"] += f"\n\n{text}"
            else:
                formatted.append({"role": role, "content": text})
        return formatted

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Emulate streaming by yielding the complete reply as one chunk.

        True incremental streaming is not implemented: the full response is
        obtained through ``_generate`` and re-emitted. The reply is wrapped in
        ``ChatGenerationChunk`` / ``AIMessageChunk`` (subclasses of
        ``ChatGeneration`` / ``AIMessage``) because those are the types
        LangChain's streaming machinery aggregates.

        Args:
            messages: Conversation so far.
            stop: Optional stop sequences, forwarded to ``_generate``.
            run_manager: Forwarded to ``_generate``.
            **kwargs: Forwarded to ``_generate``.

        Yields:
            One chunk per generation returned by ``_generate``.
        """
        result = self._generate(messages, stop, run_manager, **kwargs)
        for generation in result.generations:
            yield ChatGenerationChunk(
                message=AIMessageChunk(content=generation.message.content),
                generation_info=generation.generation_info,
            )

    def bind_tools(self, tools, **kwargs):
        """Return this model unchanged; ``tools`` and ``kwargs`` are ignored.

        Simplified on purpose: nothing is forwarded to the API, and tool
        handling is left to the calling layer.
        """
        return self

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model configuration."""
        return {
            "model_name": self.model_name,
            "base_url": self.base_url,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }