From 080ac8892f07fed5169316cbc4713d024d867113 Mon Sep 17 00:00:00 2001 From: reopio Date: Thu, 16 Oct 2025 03:32:25 +0000 Subject: [PATCH] Add support for vllm --- cli/utils.py | 115 +++++++++++++++++++-------- main.py | 2 + tradingagents/agents/utils/memory.py | 3 + tradingagents/default_config.py | 1 + tradingagents/graph/trading_graph.py | 32 ++++++-- 5 files changed, 115 insertions(+), 38 deletions(-) diff --git a/cli/utils.py b/cli/utils.py index 7b9682a6..5256b4e6 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -152,24 +152,43 @@ def select_shallow_thinking_agent(provider) -> str: "ollama": [ ("llama3.1 local", "llama3.1"), ("llama3.2 local", "llama3.2"), + ], + "vllm": [ + ("llama3.1 local", "llama3.1"), + ("qwen3", "qwen3"), ] } - choice = questionary.select( - "Select Your [Quick-Thinking LLM Engine]:", - choices=[ - questionary.Choice(display, value=value) - for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() + if provider == "vllm": + choice = questionary.text( + "Please input the vllm model name for shallow thinking (default: llama3.1):", + default="llama3.1", + validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() + + else: + # Use questionary to create an interactive selection menu for shallow thinking LLM engines + choice = questionary.select( + "Select Your [Quick-Thinking LLM Engine]:", + choices=[ + questionary.Choice(display, value=value) + for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] + ], + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() if choice is None: console.print( @@ -214,24 +233,43 @@ def select_deep_thinking_agent(provider) -> str: "ollama": [ ("llama3.1 local", "llama3.1"), ("qwen3", "qwen3"), + ], + "vllm": [ + ("llama3.1 local", "llama3.1"), + ("qwen3", "qwen3"), ] } - - choice = questionary.select( - "Select Your [Deep-Thinking LLM Engine]:", - choices=[ - questionary.Choice(display, value=value) - for display, value in DEEP_AGENT_OPTIONS[provider.lower()] - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() + + if provider == "vllm": + choice = questionary.text( + "Please input the vllm model name for deep thinking (default: llama3.1):", + default="llama3.1", + validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() + + else: + # Use questionary to create an interactive selection menu for deep thinking LLM engines + choice = questionary.select( + "Select Your [Deep-Thinking LLM Engine]:", + choices=[ + questionary.Choice(display, value=value) + for display, value in DEEP_AGENT_OPTIONS[provider.lower()] + ], + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() if choice is None: console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]") @@ -247,7 +285,8 @@ def select_llm_provider() -> tuple[str, str]: ("Anthropic", "https://api.anthropic.com/"), ("Google", "https://generativelanguage.googleapis.com/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), - ("Ollama", "http://localhost:11434/v1"), + ("Ollama", "http://localhost:11434/v1"), + ("vllm", "http://localhost:8000/v1"), ] choice = questionary.select( @@ -271,6 +310,18 @@ def select_llm_provider() -> tuple[str, str]: exit(1) display_name, url = choice + if display_name == "vllm": + url = questionary.text( + "Please input the vllm api url (default: http://localhost:8000/v1):", + default="http://localhost:8000/v1", + validate=lambda x: len(x.strip()) > 0 or "Please enter a valid URL.", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() print(f"You selected: {display_name}\tURL: {url}") return display_name, url diff --git a/main.py b/main.py index a85ee6ec..52c01f75 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,8 @@ config["data_vendors"] = { "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local } +config["embeddings"] = "text-embedding-3-small" + # Initialize with custom config ta = TradingAgentsGraph(debug=True, config=config) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 69b8ab8c..d34bcd44 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,12 +1,15 @@ import chromadb from chromadb.config import Settings from openai import OpenAI +import questionary class FinancialSituationMemory: def __init__(self, name, config): if config["backend_url"] == "http://localhost:11434/v1": self.embedding = "nomic-embed-text" + elif config["llm_provider"] == "vllm": + self.embedding = config["embeddings"] else: self.embedding = "text-embedding-3-small" self.client = OpenAI(base_url=config["backend_url"]) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 1f40a2a2..89ce97c1 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -30,4 +30,5 @@ DEFAULT_CONFIG = { # Example: "get_stock_data": "alpha_vantage", # Override category default # Example: "get_news": "openai", # Override category default }, + "embeddings": "text-embedding-3-small", # Options: text-embedding-3-small, nomic-embed-text, vllm model name } diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 40cdff75..26c8d760 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -6,6 +6,8 @@ import json from datetime import date from typing import Dict, Any, Tuple, List, Optional +import questionary + from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI @@ -72,7 +74,7 @@ class TradingAgentsGraph: ) # Initialize LLMs - if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": + if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter" or self.config["llm_provider"] == "vllm": self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) elif self.config["llm_provider"].lower() == "anthropic": @@ -85,11 +87,29 @@ class TradingAgentsGraph: raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}") # Initialize memories - self.bull_memory = FinancialSituationMemory("bull_memory", self.config) - self.bear_memory = FinancialSituationMemory("bear_memory", self.config) - self.trader_memory = FinancialSituationMemory("trader_memory", self.config) - self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) - self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) + if self.config["llm_provider"] == "vllm": + questionary.text( + "Please input the vllm embedding model name (default: None):", + default="None", + validate=lambda x: len(x.strip()) > 0 or "Please enter a valid embedding model name.", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() + self.bull_memory = FinancialSituationMemory("bull_memory", self.config) + self.bear_memory = FinancialSituationMemory("bear_memory", self.config) + self.trader_memory = FinancialSituationMemory("trader_memory", self.config) + self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) + self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) + else: + self.bull_memory = FinancialSituationMemory("bull_memory", self.config) + self.bear_memory = FinancialSituationMemory("bear_memory", self.config) + self.trader_memory = FinancialSituationMemory("trader_memory", self.config) + self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) + self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) # Create tool nodes self.tool_nodes = self._create_tool_nodes()