diff --git a/tradingagents/agents/utils/dapt_llm.py b/tradingagents/agents/utils/dapt_llm.py
new file mode 100644
index 00000000..6367f8f9
--- /dev/null
+++ b/tradingagents/agents/utils/dapt_llm.py
@@ -0,0 +1,223 @@
+"""
+LangChain wrapper for DAPTed Llama 3.1 8B model (PEFT adapter)
+"""
+import torch
+from typing import List, Optional, Any, Dict
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.runnables import Runnable
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+import os
+
+
+class DAPTLlamaChatModel(BaseChatModel):
+    """LangChain-compatible wrapper for DAPTed Llama 3.1 8B model"""
+
+    model_id: str = "meta-llama/Llama-3.1-8B"
+    dapt_adapter_path: str
+    device: Optional[str] = None
+    max_new_tokens: int = 512
+    temperature: float = 0.7
+    top_p: float = 0.9
+
+    _model: Optional[Any] = None
+    _tokenizer: Optional[Any] = None
+    _bound_tools: Optional[List[Any]] = None  # declared so bind_tools can set it as a private attr
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._load_model()
+
+    def _load_model(self):
+        """Load the DAPTed model with PEFT adapters"""
+        if self._model is not None:
+            return  # Already loaded
+
+        print(f"Loading DAPTed Llama model from {self.dapt_adapter_path}...")
+
+        # Detect device
+        if self.device is None:
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
+
+        # Set up 4-bit quantization for CUDA
+        bnb_config = None
+        if self.device == "cuda":
+            try:
+                from transformers import BitsAndBytesConfig
+                bnb_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
+                    bnb_4bit_use_double_quant=True,
+                )
+            except ImportError:
+                print("bitsandbytes not available, loading in full precision")
+
+        # Determine torch dtype
+        if self.device == "cuda":
+            torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        else:
+            torch_dtype = torch.float32
+
+        # Get HF token
+        hf_token = (
+            os.getenv("HF_TOKEN")
+            or os.getenv("HUGGINGFACE_HUB_TOKEN")
+            or os.getenv("HUGGING_FACE_HUB_TOKEN")
+        )
+
+        # Load base model
+        base_model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            device_map="auto" if self.device == "cuda" else None,
+            torch_dtype=torch_dtype,
+            quantization_config=bnb_config,
+            low_cpu_mem_usage=True,
+            token=hf_token,
+        )
+
+        # Load PEFT adapters
+        if not os.path.exists(self.dapt_adapter_path):
+            raise ValueError(f"DAPT adapter path not found: {self.dapt_adapter_path}")
+
+        self._model = PeftModel.from_pretrained(base_model, self.dapt_adapter_path)
+        self._model.eval()
+
+        # device_map only places the model on CUDA; move it explicitly for MPS
+        # so inputs and weights end up on the same device in _generate
+        if self.device == "mps":
+            self._model = self._model.to("mps")
+
+        # Load tokenizer
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=hf_token)
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
+
+        print(f"DAPTed model loaded successfully on {self.device}")
+
+    def _format_messages(self, messages: List[BaseMessage]) -> str:
+        """Convert LangChain messages to Llama 3.1 chat format"""
+        # Use tokenizer's chat template if available
+        if hasattr(self._tokenizer, 'apply_chat_template') and self._tokenizer.chat_template:
+            # Convert to format expected by tokenizer
+            formatted_messages = []
+            for msg in messages:
+                if isinstance(msg, SystemMessage):
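+                    # Map each LangChain message class onto the OpenAI-style role
+                    # names ("system"/"user"/"assistant") the chat template expects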
+                    formatted_messages.append({"role": "system", "content": msg.content})
+                elif isinstance(msg, HumanMessage):
+                    formatted_messages.append({"role": "user", "content": msg.content})
+                elif isinstance(msg, AIMessage):
+                    formatted_messages.append({"role": "assistant", "content": msg.content})
+                else:
+                    formatted_messages.append({"role": "user", "content": str(msg.content)})
+
+            # Apply chat template
+            prompt = self._tokenizer.apply_chat_template(
+                formatted_messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+            return prompt
+        else:
+            # Fall back to manual formatting
+            formatted_parts = []
+            for msg in messages:
+                if isinstance(msg, SystemMessage):
+                    formatted_parts.append(f"<|system|>\n{msg.content}<|end|>\n")
+                elif isinstance(msg, HumanMessage):
+                    formatted_parts.append(f"<|user|>\n{msg.content}<|end|>\n")
+                elif isinstance(msg, AIMessage):
+                    formatted_parts.append(f"<|assistant|>\n{msg.content}<|end|>\n")
+                else:
+                    formatted_parts.append(f"{msg.content}\n")
+
+            # Add assistant prompt
+            formatted_parts.append("<|assistant|>\n")
+            return "".join(formatted_parts)
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Generate a response from the model"""
+        if self._model is None or self._tokenizer is None:
+            self._load_model()
+
+        # Format messages
+        prompt = self._format_messages(messages)
+
+        # Tokenize
+        inputs = self._tokenizer(prompt, return_tensors="pt")
+        if self.device != "cpu":
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        # Generate
+        with torch.no_grad():
+            outputs = self._model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                do_sample=True,
+                pad_token_id=self._tokenizer.pad_token_id,
+                eos_token_id=self._tokenizer.eos_token_id,
+            )
+
+        # Decode only the newly generated tokens (skip the prompt)
+        generated_text = self._tokenizer.decode(
+            outputs[0][inputs["input_ids"].shape[1]:],
+            skip_special_tokens=True
+        )
+
+        # Create response
+        message = AIMessage(content=generated_text)
+        generation = ChatGeneration(message=message)
+
+        return ChatResult(generations=[generation])
+
+    @property
+    def _llm_type(self) -> str:
+        return "dapt_llama"
+
+    def bind_tools(self, tools: List[Any], **kwargs: Any):
+        """Bind tools - returns a runnable that handles tool calling"""
+        # Store tools for potential use in prompt enhancement
+        self._bound_tools = tools
+        return DAPTLlamaWithTools(self, tools)
+
+    def _invoke(self, messages: List[BaseMessage], **kwargs: Any) -> AIMessage:
+        """Internal invoke that returns AIMessage with tool_calls attribute"""
+        chat_result = self._generate(messages, **kwargs)
+        ai_message = chat_result.generations[0].message
+
+        # Add tool_calls attribute (empty list for now - DAPT model doesn't natively support tool calling)
+        # The analyst node will check len(tool_calls) == 0 and use content directly
+        if not hasattr(ai_message, 'tool_calls'):
+            ai_message.tool_calls = []
+
+        return ai_message
+
+
+class DAPTLlamaWithTools(Runnable):
+    """Wrapper to make DAPT LLM compatible with bind_tools interface"""
+
+    def __init__(self, llm: DAPTLlamaChatModel, tools: List[Any]):
+        self.llm = llm
+        self.tools = tools
+
+    def invoke(self, input: Any, config: Optional[Any] = None, **kwargs: Any) -> AIMessage:
+        """Invoke the LLM and return result with tool_calls attribute"""
+        if isinstance(input, dict) and "messages" in input:
+            messages = input["messages"]
+        elif isinstance(input, list):
+            messages = input
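+            # A bare list is assumed to already contain LangChain BaseMessage objects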
+        else:
+            messages = [input] if isinstance(input, BaseMessage) else [HumanMessage(content=str(input))]
+
+        return self.llm._invoke(messages, **kwargs)
+
diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py
index 61331251..838416f2 100644
--- a/tradingagents/default_config.py
+++ b/tradingagents/default_config.py
@@ -22,6 +22,12 @@ DEFAULT_CONFIG = {
     "quick_think_llm": "gpt-4o-mini",
     "backend_url": "https://api.openai.com/v1",
     "openai_api_key": os.getenv("OPENAI_API_KEY"),  # Load from .env file
+    # Sentiment analysis model (DAPTed Llama 3.1 8B)
+    "use_dapt_sentiment": True,  # Use DAPTed model for sentiment analysis (set False to use OpenAI backup)
+    # Path to DAPT PEFT adapter (dynamically uses current username)
+    "dapt_adapter_path": f"/u/v/d/{os.getenv('USER', 'negi3')}/llama3_8b_dapt_transcripts_lora",
+    # Fallback: OpenAI model if DAPT is unavailable
+    "sentiment_fallback_llm": "o1-mini",  # OpenAI model for fallback
     # Debate and discussion settings
     "max_debate_rounds": 1,
     "max_risk_discuss_rounds": 1,
diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py
index b270ffc0..accf083c 100644
--- a/tradingagents/graph/setup.py
+++ b/tradingagents/graph/setup.py
@@ -2,6 +2,7 @@ from typing import Dict, Any
 
 from langchain_openai import ChatOpenAI
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.graph import END, StateGraph, START
 from langgraph.prebuilt import ToolNode
 
@@ -18,6 +19,7 @@ class GraphSetup:
         self,
         quick_thinking_llm: ChatOpenAI,
         deep_thinking_llm: ChatOpenAI,
+        sentiment_llm: BaseChatModel,
         tool_nodes: Dict[str, ToolNode],
         bull_memory,
         bear_memory,
@@ -29,6 +31,7 @@ class GraphSetup:
         """Initialize with required components."""
         self.quick_thinking_llm = quick_thinking_llm
         self.deep_thinking_llm = deep_thinking_llm
+        self.sentiment_llm = sentiment_llm
         self.tool_nodes = tool_nodes
         self.bull_memory = bull_memory
         self.bear_memory = bear_memory
@@ -73,7 +76,7 @@ class GraphSetup:
 
         if "news" in selected_analysts:
             analyst_nodes["news"] = create_news_analyst(
-                self.quick_thinking_llm
+                self.sentiment_llm
             )
             delete_nodes["news"] = create_msg_delete()
             tool_nodes["news"] = self.tool_nodes["news"]
diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py
index 01ab3fee..d1de495f 100644
--- a/tradingagents/graph/trading_graph.py
+++ b/tradingagents/graph/trading_graph.py
@@ -9,6 +9,7 @@ from typing import Dict, Any, Tuple, List, Optional
 from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
+from tradingagents.agents.utils.dapt_llm import DAPTLlamaChatModel
 from langgraph.prebuilt import ToolNode
 
@@ -92,6 +93,34 @@ class TradingAgentsGraph:
         else:
             raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
 
+        # Initialize sentiment analysis LLM (DAPTed Llama 3.1 8B)
+        if self.config.get("use_dapt_sentiment", True):
+            # Use DAPTed model directly
+            dapt_path = self.config.get("dapt_adapter_path", f"/u/v/d/{os.getenv('USER', 'negi3')}/llama3_8b_dapt_transcripts_lora")
+            # Convert relative path to absolute if needed
+            if not os.path.isabs(dapt_path):
+                dapt_path = os.path.join(self.config["project_dir"], dapt_path)
+            try:
+                self.sentiment_llm = DAPTLlamaChatModel(
+                    dapt_adapter_path=dapt_path,
+                    max_new_tokens=512,
+                    temperature=0.7,
+                )
+            except Exception as e:
+                print(f"Warning: Failed to load DAPT model: {e}. Falling back to OpenAI.")
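+                # Any load failure (missing adapter path, gated base weights, OOM)
+                # falls through to the OpenAI model configured below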
+                self.sentiment_llm = ChatOpenAI(
+                    model=self.config.get("sentiment_fallback_llm", "o1-mini"),
+                    base_url=self.config["backend_url"],
+                    api_key=self.config.get("openai_api_key")
+                )
+        else:
+            # Fallback to OpenAI model
+            self.sentiment_llm = ChatOpenAI(
+                model=self.config.get("sentiment_fallback_llm", "o1-mini"),
+                base_url=self.config["backend_url"],
+                api_key=self.config.get("openai_api_key")
+            )
+
         # Initialize memories
         self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
         self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
@@ -107,6 +136,7 @@ class TradingAgentsGraph:
         self.graph_setup = GraphSetup(
             self.quick_thinking_llm,
             self.deep_thinking_llm,
+            self.sentiment_llm,
             self.tool_nodes,
             self.bull_memory,
             self.bear_memory,
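For reviewers, a minimal usage sketch of the new config switches. This is an assumption-laden sketch, not part of the diff: it presumes the upstream `TradingAgentsGraph(debug=..., config=...)` constructor and `propagate(ticker, date)` entry point, that the LoRA adapter directory exists at the configured path, and the ticker/date are placeholders.

```python
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph

config = DEFAULT_CONFIG.copy()
config["use_dapt_sentiment"] = True  # route the news analyst through the DAPTed Llama
# config["dapt_adapter_path"] = "/path/to/llama3_8b_dapt_transcripts_lora"  # override the $USER-based default
config["sentiment_fallback_llm"] = "o1-mini"  # used only if the adapter fails to load

# Assumed upstream API: construct the graph, then run one ticker/date (placeholders)
ta = TradingAgentsGraph(debug=True, config=config)
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
```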