add vllm support
This commit is contained in:
parent
f047f26df0
commit
34f4abf806
|
|
@ -4,3 +4,7 @@ GOOGLE_API_KEY=
|
|||
ANTHROPIC_API_KEY=
|
||||
XAI_API_KEY=
|
||||
OPENROUTER_API_KEY=
|
||||
|
||||
# vLLM Configuration (optional)
|
||||
VLLM_API_BASE=http://localhost:8000/v1
|
||||
VLLM_API_KEY=
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ DEFAULT_CONFIG = {
|
|||
# Provider-specific thinking configuration
|
||||
"google_thinking_level": None, # "high", "minimal", etc.
|
||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||
"vllm_api_base": None, # vllm API base URL (defaults to http://localhost:8000/v1)
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
|
|
|
|||
|
|
@ -145,6 +145,13 @@ class TradingAgentsGraph:
|
|||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
|
||||
elif provider == "vllm":
|
||||
# vllm specific settings
|
||||
api_base = self.config.get("vllm_api_base")
|
||||
if api_base:
|
||||
kwargs["base_url"] = api_base
|
||||
# Add any vllm-specific parameters here if needed
|
||||
|
||||
return kwargs
|
||||
|
||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
|
|||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .vllm_client import VLLMClient
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
|
|
@ -15,7 +16,7 @@ def create_llm_client(
|
|||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
|
||||
model: Model name/identifier
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
|
|
@ -40,4 +41,7 @@ def create_llm_client(
|
|||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "vllm":
|
||||
return VLLMClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
import os
from typing import Any, Optional

from langchain_openai import ChatOpenAI

from .base_client import BaseLLMClient

# Default endpoint for a locally running vLLM OpenAI-compatible server.
_DEFAULT_VLLM_API_BASE = "http://localhost:8000/v1"

# Tuning parameters forwarded to ChatOpenAI when the caller supplied them.
_FORWARDED_PARAMS = (
    "temperature",
    "top_p",
    "max_tokens",
    "timeout",
    "max_retries",
    "callbacks",
)


class VLLMClient(BaseLLMClient):
    """Client for models served through vLLM's OpenAI-compatible API.

    Connection settings are resolved in this order:

    * base URL: the ``base_url`` constructor argument, then the
      ``VLLM_API_BASE`` environment variable, then
      ``http://localhost:8000/v1``.
    * API key: an ``api_key`` keyword argument, then the ``VLLM_API_KEY``
      environment variable, then the placeholder string ``"vllm"`` (the
      OpenAI client requires a non-empty key even when the server does not
      check it).

    NOTE(review): the redundant pass-through ``__init__`` was removed; it
    only forwarded identical arguments to ``super().__init__``.  This
    assumes ``BaseLLMClient.__init__`` stores ``model``, ``base_url`` and
    the extra kwargs as ``self.model`` / ``self.base_url`` / ``self.kwargs``
    — consistent with their use below, but confirm against base_client.py.
    """

    def get_llm(self) -> Any:
        """Build and return a configured ``ChatOpenAI`` instance for vLLM.

        Returns:
            A ``ChatOpenAI`` object pointed at the resolved vLLM endpoint,
            carrying any supported tuning parameters the caller provided.
        """
        llm_kwargs: dict[str, Any] = {
            "model": self.model,
            "base_url": self.base_url
            or os.environ.get("VLLM_API_BASE", _DEFAULT_VLLM_API_BASE),
            "api_key": self.kwargs.get("api_key")
            or os.environ.get("VLLM_API_KEY")
            or "vllm",
        }

        # Forward only the explicitly supported parameters; any other
        # kwargs are intentionally ignored rather than passed blindly to
        # ChatOpenAI.
        for key in _FORWARDED_PARAMS:
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        return ChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Return True if a non-empty, non-blank model name was provided."""
        return bool(self.model and self.model.strip())
|
||||
Loading…
Reference in New Issue