Applied DAPT model to TradingAgents graph

Shashwat Negi 2025-11-09 19:56:32 -06:00
parent c4b0aa6ec9
commit e3952edf91
4 changed files with 263 additions and 1 deletion

tradingagents/agents/utils/dapt_llm.py (new file)

@@ -0,0 +1,223 @@
"""
LangChain wrapper for DAPTed Llama 3.1 8B model (PEFT adapter)
"""
import torch
from typing import List, Optional, Any, Dict
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.runnables import Runnable
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
class DAPTLlamaChatModel(BaseChatModel):
    """LangChain-compatible wrapper for DAPTed Llama 3.1 8B model"""

    model_id: str = "meta-llama/Llama-3.1-8B"
    dapt_adapter_path: str
    device: Optional[str] = None
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    _model: Optional[Any] = None
    _tokenizer: Optional[Any] = None
    # Declared up front so bind_tools() can assign it on this pydantic-based
    # model without tripping attribute validation
    _bound_tools: Optional[List[Any]] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._load_model()
    def _load_model(self):
        """Load the DAPTed model with PEFT adapters"""
        if self._model is not None:
            return  # Already loaded
        print(f"Loading DAPTed Llama model from {self.dapt_adapter_path}...")

        # Detect device
        if self.device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Setup quantization for CUDA
        bnb_config = None
        if self.device == "cuda":
            try:
                from transformers import BitsAndBytesConfig
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
                    bnb_4bit_use_double_quant=True,
                )
            except ImportError:
                print("bitsandbytes not available, loading in full precision")

        # Determine torch dtype
        if self.device == "cuda":
            torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            torch_dtype = torch.float32

        # Get HF token
        hf_token = (
            os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HUGGING_FACE_HUB_TOKEN")
        )

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch_dtype,
            quantization_config=bnb_config,
            low_cpu_mem_usage=True,
            token=hf_token,
        )

        # Load PEFT adapters
        if not os.path.exists(self.dapt_adapter_path):
            raise ValueError(f"DAPT adapter path not found: {self.dapt_adapter_path}")
        self._model = PeftModel.from_pretrained(base_model, self.dapt_adapter_path)
        self._model.eval()

        # Load tokenizer
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=hf_token)
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token
        print(f"DAPTed model loaded successfully on {self.device}")
    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Convert LangChain messages to Llama 3.1 chat format"""
        # Use tokenizer's chat template if available
        if hasattr(self._tokenizer, 'apply_chat_template') and self._tokenizer.chat_template:
            # Convert to the role/content dicts expected by the tokenizer
            formatted_messages = []
            for msg in messages:
                if isinstance(msg, SystemMessage):
                    formatted_messages.append({"role": "system", "content": msg.content})
                elif isinstance(msg, HumanMessage):
                    formatted_messages.append({"role": "user", "content": msg.content})
                elif isinstance(msg, AIMessage):
                    formatted_messages.append({"role": "assistant", "content": msg.content})
                else:
                    formatted_messages.append({"role": "user", "content": str(msg.content)})
            # Apply chat template
            prompt = self._tokenizer.apply_chat_template(
                formatted_messages,
                tokenize=False,
                add_generation_prompt=True
            )
            return prompt
        else:
            # Fallback to manual formatting with generic role markers (base,
            # non-Instruct checkpoints typically ship without a chat template)
            formatted_parts = []
            for msg in messages:
                if isinstance(msg, SystemMessage):
                    formatted_parts.append(f"<|system|>\n{msg.content}<|end|>\n")
                elif isinstance(msg, HumanMessage):
                    formatted_parts.append(f"<|user|>\n{msg.content}<|end|>\n")
                elif isinstance(msg, AIMessage):
                    formatted_parts.append(f"<|assistant|>\n{msg.content}<|end|>\n")
                else:
                    formatted_parts.append(f"{msg.content}\n")
            # Add assistant prompt
            formatted_parts.append("<|assistant|>\n")
            return "".join(formatted_parts)
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response from the model"""
        if self._model is None or self._tokenizer is None:
            self._load_model()

        # Format messages
        prompt = self._format_messages(messages)

        # Tokenize
        inputs = self._tokenizer(prompt, return_tensors="pt")
        if self.device != "cpu":
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (skip the prompt)
        generated_text = self._tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )

        # Create response
        message = AIMessage(content=generated_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])
    @property
    def _llm_type(self) -> str:
        return "dapt_llama"

    def bind_tools(self, tools: List[Any], **kwargs: Any):
        """Bind tools - returns a runnable that handles tool calling"""
        # Store tools for potential use in prompt enhancement
        self._bound_tools = tools
        return DAPTLlamaWithTools(self, tools)

    def _invoke(self, messages: List[BaseMessage], **kwargs: Any) -> AIMessage:
        """Internal invoke that returns AIMessage with tool_calls attribute"""
        chat_result = self._generate(messages, **kwargs)
        ai_message = chat_result.generations[0].message
        # Add tool_calls attribute (empty list for now - the DAPT model doesn't
        # natively support tool calling); the analyst node will check
        # len(tool_calls) == 0 and use the content directly
        if not hasattr(ai_message, 'tool_calls'):
            ai_message.tool_calls = []
        return ai_message
class DAPTLlamaWithTools(Runnable):
    """Wrapper to make DAPT LLM compatible with bind_tools interface"""

    def __init__(self, llm: DAPTLlamaChatModel, tools: List[Any]):
        self.llm = llm
        self.tools = tools

    def invoke(self, input: Any, config: Optional[Any] = None, **kwargs: Any) -> AIMessage:
        """Invoke the LLM and return result with tool_calls attribute"""
        if isinstance(input, dict) and "messages" in input:
            messages = input["messages"]
        elif isinstance(input, list):
            messages = input
        else:
            messages = [input] if isinstance(input, BaseMessage) else [HumanMessage(content=str(input))]
        return self.llm._invoke(messages, **kwargs)
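
A minimal smoke-test sketch for the wrapper (illustrative only: the adapter path below is a placeholder, and loading assumes an HF token with access to the gated meta-llama/Llama-3.1-8B weights):

    from langchain_core.messages import HumanMessage, SystemMessage
    from tradingagents.agents.utils.dapt_llm import DAPTLlamaChatModel

    # Constructing the model triggers _load_model(), so this is expensive
    llm = DAPTLlamaChatModel(dapt_adapter_path="/path/to/llama3_8b_dapt_transcripts_lora")
    reply = llm.invoke([
        SystemMessage(content="You are a financial sentiment analyst."),
        HumanMessage(content="Assess the sentiment: 'Company raised full-year guidance.'"),
    ])
    print(reply.content)  # plain text; reply.tool_calls stays an empty list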

tradingagents/default_config.py

@@ -22,6 +22,12 @@ DEFAULT_CONFIG = {
     "quick_think_llm": "gpt-4o-mini",
     "backend_url": "https://api.openai.com/v1",
     "openai_api_key": os.getenv("OPENAI_API_KEY"),  # Load from .env file
+    # Sentiment analysis model (DAPTed Llama 3.1 8B)
+    "use_dapt_sentiment": True,  # Use the DAPTed model for sentiment analysis (set False to use the OpenAI fallback)
+    # Path to the DAPT PEFT adapter (dynamically uses the current username)
+    "dapt_adapter_path": f"/u/v/d/{os.getenv('USER', 'negi3')}/llama3_8b_dapt_transcripts_lora",
+    # Fallback: OpenAI model if DAPT is unavailable
+    "sentiment_fallback_llm": "o1-mini",  # OpenAI model for the fallback
     # Debate and discussion settings
     "max_debate_rounds": 1,
     "max_risk_discuss_rounds": 1,

tradingagents/graph/setup.py

@@ -2,6 +2,7 @@
 from typing import Dict, Any
 from langchain_openai import ChatOpenAI
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.graph import END, StateGraph, START
 from langgraph.prebuilt import ToolNode
@@ -18,6 +19,7 @@ class GraphSetup:
         self,
         quick_thinking_llm: ChatOpenAI,
         deep_thinking_llm: ChatOpenAI,
+        sentiment_llm: BaseChatModel,
         tool_nodes: Dict[str, ToolNode],
         bull_memory,
         bear_memory,
@@ -29,6 +31,7 @@
         """Initialize with required components."""
         self.quick_thinking_llm = quick_thinking_llm
         self.deep_thinking_llm = deep_thinking_llm
+        self.sentiment_llm = sentiment_llm
         self.tool_nodes = tool_nodes
         self.bull_memory = bull_memory
         self.bear_memory = bear_memory
@@ -73,7 +76,7 @@
         if "news" in selected_analysts:
             analyst_nodes["news"] = create_news_analyst(
-                self.quick_thinking_llm
+                self.sentiment_llm
             )
             delete_nodes["news"] = create_msg_delete()
             tool_nodes["news"] = self.tool_nodes["news"]

tradingagents/graph/trading_graph.py

@@ -9,6 +9,7 @@
 from typing import Dict, Any, Tuple, List, Optional
 from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
+from tradingagents.agents.utils.dapt_llm import DAPTLlamaChatModel
 from langgraph.prebuilt import ToolNode
@@ -92,6 +93,34 @@ class TradingAgentsGraph:
         else:
             raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")

+        # Initialize sentiment analysis LLM (DAPTed Llama 3.1 8B)
+        if self.config.get("use_dapt_sentiment", True):
+            # Use DAPTed model directly
+            dapt_path = self.config.get("dapt_adapter_path", f"/u/v/d/{os.getenv('USER', 'negi3')}/llama3_8b_dapt_transcripts_lora")
+            # Convert relative path to absolute if needed
+            if not os.path.isabs(dapt_path):
+                dapt_path = os.path.join(self.config["project_dir"], dapt_path)
+            try:
+                self.sentiment_llm = DAPTLlamaChatModel(
+                    dapt_adapter_path=dapt_path,
+                    max_new_tokens=512,
+                    temperature=0.7,
+                )
+            except Exception as e:
+                print(f"Warning: Failed to load DAPT model: {e}. Falling back to OpenAI.")
+                self.sentiment_llm = ChatOpenAI(
+                    model=self.config.get("sentiment_fallback_llm", "o1-mini"),
+                    base_url=self.config["backend_url"],
+                    api_key=self.config.get("openai_api_key"),
+                )
+        else:
+            # Fallback to OpenAI model
+            self.sentiment_llm = ChatOpenAI(
+                model=self.config.get("sentiment_fallback_llm", "o1-mini"),
+                base_url=self.config["backend_url"],
+                api_key=self.config.get("openai_api_key"),
+            )

         # Initialize memories
         self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
         self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
@@ -107,6 +136,7 @@
         self.graph_setup = GraphSetup(
             self.quick_thinking_llm,
             self.deep_thinking_llm,
+            self.sentiment_llm,
             self.tool_nodes,
             self.bull_memory,
             self.bear_memory,
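
End to end, the selection logic above means callers only touch the config; a hedged sketch (constructor arguments per the upstream TradingAgents README, which may differ in this fork):

    from tradingagents.graph.trading_graph import TradingAgentsGraph
    from tradingagents.default_config import DEFAULT_CONFIG

    config = dict(DEFAULT_CONFIG)
    config["use_dapt_sentiment"] = True  # try the DAPTed adapter first
    ta = TradingAgentsGraph(debug=True, config=config)  # falls back to o1-mini if loading fails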