diff --git a/.env.example b/.env.example
index 1328b838..fd1468e2 100644
--- a/.env.example
+++ b/.env.example
@@ -4,3 +4,7 @@ GOOGLE_API_KEY=
 ANTHROPIC_API_KEY=
 XAI_API_KEY=
 OPENROUTER_API_KEY=
+
+# vLLM configuration (optional)
+VLLM_API_BASE=http://localhost:8000/v1
+VLLM_API_KEY=
diff --git a/cli/main.py b/cli/main.py
index adda48fc..10779fb5 100644
--- a/cli/main.py
+++ b/cli/main.py
@@ -556,6 +556,8 @@ def get_user_selections():
     # Step 7: Provider-specific thinking configuration
     thinking_level = None
     reasoning_effort = None
+    vllm_api_base = None
+    vllm_api_key = None
 
     provider_lower = selected_llm_provider.lower()
     if provider_lower == "google":
@@ -574,6 +576,14 @@
             )
         )
         reasoning_effort = ask_openai_reasoning_effort()
+    elif provider_lower == "vllm":
+        console.print(
+            create_question_box(
+                "Step 7: vLLM Configuration",
+                "Configure the vLLM API endpoint and key"
+            )
+        )
+        vllm_api_base, vllm_api_key = ask_vllm_config()
 
     return {
         "ticker": selected_ticker,
@@ -586,6 +596,8 @@
         "deep_thinker": selected_deep_thinker,
         "google_thinking_level": thinking_level,
         "openai_reasoning_effort": reasoning_effort,
+        "vllm_api_base": vllm_api_base,
+        "vllm_api_key": vllm_api_key,
     }
@@ -911,7 +923,11 @@ def run_analysis():
     # Provider-specific thinking configuration
     config["google_thinking_level"] = selections.get("google_thinking_level")
     config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
-
+    if selections["llm_provider"].lower() == "vllm":
+        if selections.get("vllm_api_base"):
+            config["backend_url"] = selections["vllm_api_base"]
+        config["vllm_api_base"] = selections.get("vllm_api_base")
+        config["vllm_api_key"] = selections.get("vllm_api_key")
     # Create stats callback handler for tracking LLM/tool calls
     stats_handler = StatsCallbackHandler()
 
diff --git a/cli/utils.py b/cli/utils.py
index 5a8ec16c..a0d36108 100644
--- a/cli/utils.py
+++ b/cli/utils.py
@@ -164,6 +164,7 @@ def select_shallow_thinking_agent(provider) -> str:
             ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
             ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
         ],
+        "vllm": [("Qwen/Qwen3.5-2B", "Qwen/Qwen3.5-2B")],
     }
 
     choice = questionary.select(
@@ -231,6 +232,7 @@ def select_deep_thinking_agent(provider) -> str:
             ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
             ("Qwen3:latest (8B, local)", "qwen3:latest"),
         ],
+        "vllm": [("Qwen/Qwen3.5-2B", "Qwen/Qwen3.5-2B")],
     }
 
     choice = questionary.select(
@@ -265,6 +267,7 @@ def select_llm_provider() -> tuple[str, str]:
         ("xAI", "https://api.x.ai/v1"),
         ("Openrouter", "https://openrouter.ai/api/v1"),
         ("Ollama", "http://localhost:11434/v1"),
+        ("vLLM", "http://localhost:8000/v1"),
     ]
 
     choice = questionary.select(
@@ -329,3 +332,26 @@
             ("pointer", "fg:green noinherit"),
         ]),
     ).ask()
+
+
+def ask_vllm_config() -> tuple[str | None, str | None]:
+    """Ask for vLLM configuration.
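+
+    Prompts for the vLLM API base URL and API key, using the VLLM_API_BASE
+    and VLLM_API_KEY environment variables as defaults when set. Each value
+    is returned as None if the prompt is left blank.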
""" + import os + default_base = os.environ.get("VLLM_API_BASE", "https://localhost:8000/v1") + api_base = questionary.text("Enter VLLM API URL:", default=default_base, style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ])).ask() + default_api_key = os.environ.get("VLLM_API_KEY", "") + api_key = questionary.text("Enter VLLM API Key:", default=default_api_key, style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ])).ask() + return api_base.strip() if api_base else None, api_key.strip() if api_key else None diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index ecf0dc29..305ed4b3 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -15,6 +15,7 @@ DEFAULT_CONFIG = { # Provider-specific thinking configuration "google_thinking_level": None, # "high", "minimal", etc. "openai_reasoning_effort": None, # "medium", "high", "low" + "vllm_api_base": None, # vllm API base URL (defaults to http://localhost:8000/v1) # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c7ef0f98..0e143f8a 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -78,16 +78,21 @@ class TradingAgentsGraph: if self.callbacks: llm_kwargs["callbacks"] = self.callbacks + provider = self.config["llm_provider"].lower() + if provider == "vllm": + base_url = self.config.get("vllm_api_base") or self.config["backend_url"] + else: + base_url = self.config["backend_url"] deep_client = create_llm_client( provider=self.config["llm_provider"], model=self.config["deep_think_llm"], - base_url=self.config.get("backend_url"), + base_url=base_url, **llm_kwargs, ) quick_client = create_llm_client( provider=self.config["llm_provider"], model=self.config["quick_think_llm"], - base_url=self.config.get("backend_url"), + base_url=base_url, **llm_kwargs, ) @@ -148,6 +153,13 @@ class TradingAgentsGraph: if reasoning_effort: kwargs["reasoning_effort"] = reasoning_effort + elif provider == "vllm": + # vllm specific settings + api_key = self.config.get("vllm_api_key") + if api_key: + kwargs["api_key"] = api_key + # Add any vllm-specific parameters here if needed + return kwargs def _create_tool_nodes(self) -> Dict[str, ToolNode]: diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 93c2a7d3..57b3ef77 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -4,6 +4,7 @@ from .base_client import BaseLLMClient from .openai_client import OpenAIClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient +from .vllm_client import VLLMClient def create_llm_client( @@ -15,7 +16,7 @@ def create_llm_client( """Create an LLM client for the specified provider. 
 
     Args:
-        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
+        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
         model: Model name/identifier
         base_url: Optional base URL for API endpoint
         **kwargs: Additional provider-specific arguments
@@ -46,4 +47,7 @@
     if provider_lower == "google":
         return GoogleClient(model, base_url, **kwargs)
 
+    if provider_lower == "vllm":
+        return VLLMClient(model, base_url, **kwargs)
+
     raise ValueError(f"Unsupported LLM provider: {provider}")
diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py
index 1e2388b3..144c807d 100644
--- a/tradingagents/llm_clients/validators.py
+++ b/tradingagents/llm_clients/validators.py
@@ -54,11 +54,11 @@
 def validate_model(provider: str, model: str) -> bool:
     """Check if model name is valid for the given provider.
 
-    For ollama, openrouter - any model is accepted.
+    For ollama, openrouter, vllm - any model is accepted.
     """
     provider_lower = provider.lower()
 
-    if provider_lower in ("ollama", "openrouter"):
+    if provider_lower in ("ollama", "openrouter", "vllm"):
         return True
 
     if provider_lower not in VALID_MODELS:
diff --git a/tradingagents/llm_clients/vllm_client.py b/tradingagents/llm_clients/vllm_client.py
new file mode 100644
index 00000000..2779313b
--- /dev/null
+++ b/tradingagents/llm_clients/vllm_client.py
@@ -0,0 +1,33 @@
+import os
+from typing import Any, Optional
+
+from langchain_openai import ChatOpenAI
+
+from .base_client import BaseLLMClient
+
+
+class VLLMClient(BaseLLMClient):
+    """Client for vLLM-served models, using vLLM's OpenAI-compatible API."""
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        super().__init__(model, base_url, **kwargs)
+
+    def get_llm(self) -> Any:
+        """Return a configured ChatOpenAI instance for vLLM."""
+        llm_kwargs = {
+            "model": self.model,
+            "base_url": self.base_url or os.environ.get("VLLM_API_BASE", "http://localhost:8000/v1"),
+            # vLLM ignores the key unless the server enforces one, but ChatOpenAI requires a non-empty value
+            "api_key": self.kwargs.get("api_key") or os.environ.get("VLLM_API_KEY") or "vllm",
+        }
+
+        # Forward supported generation/runtime parameters when provided
+        for key in ("temperature", "top_p", "max_tokens", "timeout", "max_retries", "callbacks"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        return ChatOpenAI(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Validate that a model name is provided."""
+        return bool(self.model and self.model.strip())
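
A minimal smoke test of the new client path, as a sketch: it assumes a vLLM server is already serving a model locally (for example, started with "vllm serve Qwen/Qwen3.5-2B --port 8000"); the model name and prompt are illustrative.

from tradingagents.llm_clients.factory import create_llm_client

# Sketch: assumes a vLLM OpenAI-compatible server at http://localhost:8000/v1
client = create_llm_client(
    provider="vllm",
    model="Qwen/Qwen3.5-2B",              # any model the server is actually serving
    base_url="http://localhost:8000/v1",  # optional; falls back to VLLM_API_BASE
    api_key="vllm",                       # placeholder unless the server enforces a key
)
llm = client.get_llm()                    # returns a configured ChatOpenAI instance
print(llm.invoke("Reply with one word: ready").content)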