add vllm support
This commit is contained in:
parent
f047f26df0
commit
34f4abf806
|
|
@ -4,3 +4,7 @@ GOOGLE_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
XAI_API_KEY=
|
XAI_API_KEY=
|
||||||
OPENROUTER_API_KEY=
|
OPENROUTER_API_KEY=
|
||||||
|
|
||||||
|
# vLLM configuration (optional): base URL and API key of an OpenAI-compatible vLLM server
|
||||||
|
VLLM_API_BASE=http://localhost:8000/v1
|
||||||
|
VLLM_API_KEY=
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ DEFAULT_CONFIG = {
|
||||||
# Provider-specific thinking configuration
|
# Provider-specific thinking configuration
|
||||||
"google_thinking_level": None, # "high", "minimal", etc.
|
"google_thinking_level": None, # "high", "minimal", etc.
|
||||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||||
|
"vllm_api_base": None, # vllm API base URL (defaults to http://localhost:8000/v1)
|
||||||
# Debate and discussion settings
|
# Debate and discussion settings
|
||||||
"max_debate_rounds": 1,
|
"max_debate_rounds": 1,
|
||||||
"max_risk_discuss_rounds": 1,
|
"max_risk_discuss_rounds": 1,
|
||||||
|
|
|
||||||
|
|
@ -145,6 +145,13 @@ class TradingAgentsGraph:
|
||||||
if reasoning_effort:
|
if reasoning_effort:
|
||||||
kwargs["reasoning_effort"] = reasoning_effort
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
|
|
||||||
|
elif provider == "vllm":
|
||||||
|
# vllm specific settings
|
||||||
|
api_base = self.config.get("vllm_api_base")
|
||||||
|
if api_base:
|
||||||
|
kwargs["base_url"] = api_base
|
||||||
|
# Add any vllm-specific parameters here if needed
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
from .openai_client import OpenAIClient
|
||||||
from .anthropic_client import AnthropicClient
|
from .anthropic_client import AnthropicClient
|
||||||
from .google_client import GoogleClient
|
from .google_client import GoogleClient
|
||||||
|
from .vllm_client import VLLMClient
|
||||||
|
|
||||||
|
|
||||||
def create_llm_client(
|
def create_llm_client(
|
||||||
|
|
@ -15,7 +16,7 @@ def create_llm_client(
|
||||||
"""Create an LLM client for the specified provider.
|
"""Create an LLM client for the specified provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
|
||||||
model: Model name/identifier
|
model: Model name/identifier
|
||||||
base_url: Optional base URL for API endpoint
|
base_url: Optional base URL for API endpoint
|
||||||
**kwargs: Additional provider-specific arguments
|
**kwargs: Additional provider-specific arguments
|
||||||
|
|
@ -40,4 +41,7 @@ def create_llm_client(
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
return GoogleClient(model, base_url, **kwargs)
|
return GoogleClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "vllm":
|
||||||
|
return VLLMClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class VLLMClient(BaseLLMClient):
    """Client for vLLM-served models, accessed through the OpenAI-compatible API.

    The endpoint and API key are each resolved from the first non-empty
    source in this order:

    * base URL: the ``base_url`` constructor argument, the ``VLLM_API_BASE``
      environment variable, then ``DEFAULT_BASE_URL``.
    * API key: the ``api_key`` keyword argument, the ``VLLM_API_KEY``
      environment variable, then ``DEFAULT_API_KEY``.
    """

    # Endpoint used when neither the caller nor the environment supplies one.
    DEFAULT_BASE_URL = "http://localhost:8000/v1"

    # Placeholder key used when none is configured.
    # NOTE(review): presumably needed because the OpenAI client rejects a
    # missing key even when the vLLM server does not check it - confirm.
    DEFAULT_API_KEY = "vllm"

    # Constructor kwargs that are forwarded to ChatOpenAI unchanged.
    _PASSTHROUGH_KEYS = (
        "temperature",
        "top_p",
        "max_tokens",
        "timeout",
        "max_retries",
        "callbacks",
    )

    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
        """Store the model name, optional endpoint and extra options.

        Args:
            model: Name/identifier of the model served by vLLM.
            base_url: Optional API endpoint; overrides ``VLLM_API_BASE``.
            **kwargs: Extra options; ``api_key`` and the keys listed in
                ``_PASSTHROUGH_KEYS`` are consumed by :meth:`get_llm`.
        """
        # NOTE(review): assumes BaseLLMClient.__init__ stores these as
        # self.model, self.base_url and self.kwargs, which the methods
        # below read - confirm against base_client.
        super().__init__(model, base_url, **kwargs)

    def get_llm(self) -> Any:
        """Return a ``ChatOpenAI`` instance configured for the vLLM endpoint."""
        # Use ``or`` chaining for both settings so that an environment
        # variable that is set but empty (e.g. ``VLLM_API_BASE=`` in a .env
        # file) falls through to the default instead of yielding "".
        base_url = (
            self.base_url
            or os.environ.get("VLLM_API_BASE")
            or self.DEFAULT_BASE_URL
        )
        api_key = (
            self.kwargs.get("api_key")
            or os.environ.get("VLLM_API_KEY")
            or self.DEFAULT_API_KEY
        )

        llm_kwargs = {
            "model": self.model,
            "base_url": base_url,
            "api_key": api_key,
        }

        # Forward only the supported options the caller actually provided;
        # explicit falsy values such as temperature=0 are preserved.
        for key in self._PASSTHROUGH_KEYS:
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        return ChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Return True when a non-blank model name was provided."""
        return bool(self.model and self.model.strip())
|
||||||
Loading…
Reference in New Issue