add vllm support

xjx 2026-03-06 15:34:57 +08:00
parent f047f26df0
commit 34f4abf806
5 changed files with 49 additions and 1 deletion

View File

@@ -4,3 +4,7 @@ GOOGLE_API_KEY=
ANTHROPIC_API_KEY=
XAI_API_KEY=
OPENROUTER_API_KEY=
# vLLM configuration (optional)
VLLM_API_BASE=http://localhost:8000/v1
VLLM_API_KEY=

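For these settings to do anything, a vLLM server must be listening on its OpenAI-compatible endpoint at VLLM_API_BASE. On recent vLLM releases a minimal way to start one (the model name is a placeholder):

    vllm serve Qwen/Qwen2.5-7B-Instruct --port 8000

VLLM_API_KEY can stay empty for a default local server; the client added below falls back to a dummy key.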
View File

@@ -15,6 +15,7 @@ DEFAULT_CONFIG = {
    # Provider-specific thinking configuration
    "google_thinking_level": None,  # "high", "minimal", etc.
    "openai_reasoning_effort": None,  # "medium", "high", "low"
    "vllm_api_base": None,  # vLLM API base URL (defaults to http://localhost:8000/v1)
    # Debate and discussion settings
    "max_debate_rounds": 1,
    "max_risk_discuss_rounds": 1,

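A quick sketch of exercising the new key, following the usual copy-and-override pattern; the "llm_provider" key name and the TradingAgentsGraph constructor signature are assumptions, not shown in this diff:

    config = DEFAULT_CONFIG.copy()
    config["llm_provider"] = "vllm"                       # hypothetical key name
    config["vllm_api_base"] = "http://localhost:8000/v1"  # new key from this commit
    graph = TradingAgentsGraph(config=config)             # constructor signature assumed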
View File

@@ -145,6 +145,13 @@ class TradingAgentsGraph:
            if reasoning_effort:
                kwargs["reasoning_effort"] = reasoning_effort
        elif provider == "vllm":
            # vLLM-specific settings
            api_base = self.config.get("vllm_api_base")
            if api_base:
                kwargs["base_url"] = api_base
            # Add any vLLM-specific parameters here if needed
        return kwargs

    def _create_tool_nodes(self) -> Dict[str, ToolNode]:

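Concretely, if config["vllm_api_base"] is set, the branch above contributes a single entry, roughly:

    kwargs == {"base_url": "http://localhost:8000/v1"}

which presumably flows on to the client factory below.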
View File

@@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .vllm_client import VLLMClient


def create_llm_client(
@@ -15,7 +16,7 @@ def create_llm_client(
    """Create an LLM client for the specified provider.

    Args:
        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
        model: Model name/identifier
        base_url: Optional base URL for API endpoint
        **kwargs: Additional provider-specific arguments
@@ -40,4 +41,7 @@ def create_llm_client(
    if provider_lower == "google":
        return GoogleClient(model, base_url, **kwargs)
    if provider_lower == "vllm":
        return VLLMClient(model, base_url, **kwargs)
    raise ValueError(f"Unsupported LLM provider: {provider}")

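A minimal sketch of calling the factory for vLLM; the import path and model name are assumptions, not taken from this diff:

    from llm_clients import create_llm_client  # actual package path may differ

    client = create_llm_client(
        "vllm",
        "Qwen/Qwen2.5-7B-Instruct",           # placeholder model name
        base_url="http://localhost:8000/v1",  # optional; the client also falls back to VLLM_API_BASE
    )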
View File

@@ -0,0 +1,32 @@
import os
from typing import Any, Optional

from langchain_openai import ChatOpenAI

from .base_client import BaseLLMClient


class VLLMClient(BaseLLMClient):
    """Client for vLLM models, served over the OpenAI-compatible API."""

    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
        super().__init__(model, base_url, **kwargs)

    def get_llm(self) -> Any:
        """Return a configured ChatOpenAI instance pointed at the vLLM server."""
        llm_kwargs = {
            "model": self.model,
            "base_url": self.base_url or os.environ.get("VLLM_API_BASE", "http://localhost:8000/v1"),
            "api_key": self.kwargs.get("api_key") or os.environ.get("VLLM_API_KEY") or "vllm",
        }
        # Forward supported generation/client parameters when provided
        for key in ("temperature", "top_p", "max_tokens", "timeout", "max_retries", "callbacks"):
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]
        return ChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Validate that a model name is provided."""
        return bool(self.model and self.model.strip())
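And a direct smoke test of the new client, assuming a vLLM server is already running locally (the model name is a placeholder):

    client = VLLMClient("Qwen/Qwen2.5-7B-Instruct", temperature=0.2)
    assert client.validate_model()
    llm = client.get_llm()  # ChatOpenAI aimed at http://localhost:8000/v1
    print(llm.invoke("Reply with one word: hello").content)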