TradingAgents/tradingagents/agents/utils/dapt_llm.py

"""
LangChain wrapper for DAPTed Llama 3.1 8B model (PEFT adapter)
"""
import os
from typing import Any, Dict, List, Optional

import torch
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


class DAPTLlamaChatModel(BaseChatModel):
    """LangChain-compatible wrapper for a DAPTed Llama 3.1 8B model"""

    model_id: str = "meta-llama/Llama-3.1-8B"
    dapt_adapter_path: str
    device: Optional[str] = None
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    _model: Optional[Any] = None
    _tokenizer: Optional[Any] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._load_model()

    def _load_model(self):
        """Load the DAPTed model with PEFT adapters"""
        if self._model is not None:
            return  # Already loaded
        print(f"Loading DAPTed Llama model from {self.dapt_adapter_path}...")

        # Detect device
        if self.device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Setup 4-bit quantization for CUDA
        bnb_config = None
        if self.device == "cuda":
            try:
                from transformers import BitsAndBytesConfig

                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
                    bnb_4bit_use_double_quant=True,
                )
            except ImportError:
                print("bitsandbytes not available, loading in full precision")

        # Determine torch dtype
        if self.device == "cuda":
            torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            torch_dtype = torch.float32

        # Get HF token from any of the common environment variable names
        hf_token = (
            os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HUGGING_FACE_HUB_TOKEN")
        )

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch_dtype,
            quantization_config=bnb_config,
            low_cpu_mem_usage=True,
            token=hf_token,
        )

        # Load PEFT adapters
        if not os.path.exists(self.dapt_adapter_path):
            raise ValueError(f"DAPT adapter path not found: {self.dapt_adapter_path}")
        self._model = PeftModel.from_pretrained(base_model, self.dapt_adapter_path)
        self._model.eval()

        # Without a device_map the model loads on CPU, so move it explicitly on
        # MPS (no-op on CPU); otherwise generation inputs and weights end up on
        # different devices.
        if self.device != "cuda":
            self._model = self._model.to(self.device)

        # Load tokenizer
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=hf_token)
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token
        print(f"DAPTed model loaded successfully on {self.device}")

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Convert LangChain messages to Llama 3.1 chat format"""
        # Use the tokenizer's chat template if available
        if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template:
            # Convert to the role/content format expected by the tokenizer
            formatted_messages = []
            for msg in messages:
                if isinstance(msg, SystemMessage):
                    formatted_messages.append({"role": "system", "content": msg.content})
                elif isinstance(msg, HumanMessage):
                    formatted_messages.append({"role": "user", "content": msg.content})
                elif isinstance(msg, AIMessage):
                    formatted_messages.append({"role": "assistant", "content": msg.content})
                else:
                    formatted_messages.append({"role": "user", "content": str(msg.content)})
            # Apply chat template
            prompt = self._tokenizer.apply_chat_template(
                formatted_messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            return prompt
        else:
            # Fallback: manual Llama 3 style formatting using header tokens
            #   <|begin_of_text|>
            #   <|start_header_id|>system<|end_header_id|>
            #   {content}<|eot_id|> ...
            bos = "<|begin_of_text|>"
            start_header = "<|start_header_id|>"
            end_header = "<|end_header_id|>"
            eot = "<|eot_id|>"
            parts: List[str] = [bos]
            for msg in messages:
                if isinstance(msg, SystemMessage):
                    parts.append(f"{start_header}system{end_header}\n{msg.content}{eot}\n")
                elif isinstance(msg, HumanMessage):
                    parts.append(f"{start_header}user{end_header}\n{msg.content}{eot}\n")
                elif isinstance(msg, AIMessage):
                    parts.append(f"{start_header}assistant{end_header}\n{msg.content}{eot}\n")
                else:
                    parts.append(f"{start_header}user{end_header}\n{str(msg.content)}{eot}\n")
            # Add the assistant header to cue generation
            parts.append(f"{start_header}assistant{end_header}\n")
            return "".join(parts)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response from the model"""
        if self._model is None or self._tokenizer is None:
            self._load_model()

        # Format messages into a single prompt string
        prompt = self._format_messages(messages)

        # Tokenize
        inputs = self._tokenizer(prompt, return_tensors="pt")
        if self.device != "cpu":
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (skip the prompt)
        generated_text = self._tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        # Create response
        message = AIMessage(content=generated_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "dapt_llama"

    def bind_tools(self, tools: List[Any], **kwargs: Any):
        """Bind tools - returns a runnable that handles tool calling"""
        # Store tools for potential use in prompt enhancement
        self._bound_tools = tools
        return DAPTLlamaWithTools(self, tools)

    def _invoke(self, messages: List[BaseMessage], **kwargs: Any) -> AIMessage:
        """Internal invoke that returns an AIMessage with a tool_calls attribute"""
        chat_result = self._generate(messages, **kwargs)
        ai_message = chat_result.generations[0].message
        # Ensure a tool_calls attribute exists (empty list for now - the DAPT
        # model doesn't natively support tool calling). The analyst node checks
        # len(tool_calls) == 0 and uses the content directly.
        if not hasattr(ai_message, "tool_calls"):
            ai_message.tool_calls = []
        return ai_message


class DAPTLlamaWithTools(Runnable):
    """Wrapper to make the DAPT LLM compatible with the bind_tools interface"""

    def __init__(self, llm: DAPTLlamaChatModel, tools: List[Any]):
        self.llm = llm
        self.tools = tools

    def invoke(self, input: Any, config: Optional[Any] = None, **kwargs: Any) -> AIMessage:
        """Invoke the LLM and return a result with a tool_calls attribute"""
        # Accept the common LangChain input shapes: a state dict with
        # "messages", a list of messages, a single message, or a raw string.
        if isinstance(input, dict) and "messages" in input:
            messages = input["messages"]
        elif isinstance(input, list):
            messages = input
        else:
            messages = [input] if isinstance(input, BaseMessage) else [HumanMessage(content=str(input))]
        return self.llm._invoke(messages, **kwargs)
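

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it shows how the
    # wrapper plugs into the standard LangChain chat-model interface. The
    # adapter path below is a placeholder and must point at a real DAPT/PEFT
    # checkpoint directory before this will run.
    llm = DAPTLlamaChatModel(dapt_adapter_path="./checkpoints/dapt-llama-3.1-8b")
    result = llm.invoke(
        [
            SystemMessage(content="You are a financial analyst."),
            HumanMessage(content="Summarize the key risks of holding NVDA this quarter."),
        ]
    )
    print(result.content)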