diff --git a/.env.example b/.env.example
index 1328b838..fd1468e2 100644
--- a/.env.example
+++ b/.env.example
@@ -4,3 +4,7 @@ GOOGLE_API_KEY=
 ANTHROPIC_API_KEY=
 XAI_API_KEY=
 OPENROUTER_API_KEY=
+
+# vLLM Configuration (optional)
+VLLM_API_BASE=http://localhost:8000/v1
+VLLM_API_KEY=
diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py
index ecf0dc29..305ed4b3 100644
--- a/tradingagents/default_config.py
+++ b/tradingagents/default_config.py
@@ -15,6 +15,7 @@ DEFAULT_CONFIG = {
     # Provider-specific thinking configuration
    "google_thinking_level": None, # "high", "minimal", etc.
    "openai_reasoning_effort": None, # "medium", "high", "low"
+    "vllm_api_base": None, # vLLM API base URL (defaults to http://localhost:8000/v1)
     # Debate and discussion settings
    "max_debate_rounds": 1,
    "max_risk_discuss_rounds": 1,
diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py
index 44ecca0c..f5eb1902 100644
--- a/tradingagents/graph/trading_graph.py
+++ b/tradingagents/graph/trading_graph.py
@@ -145,6 +145,13 @@ class TradingAgentsGraph:
             if reasoning_effort:
                 kwargs["reasoning_effort"] = reasoning_effort
 
+        elif provider == "vllm":
+            # vLLM-specific settings
+            api_base = self.config.get("vllm_api_base")
+            if api_base:
+                kwargs["base_url"] = api_base
+            # Add any vLLM-specific parameters here if needed
+
         return kwargs
 
     def _create_tool_nodes(self) -> Dict[str, ToolNode]:
diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py
index 028c88a2..e10e83da 100644
--- a/tradingagents/llm_clients/factory.py
+++ b/tradingagents/llm_clients/factory.py
@@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
 from .openai_client import OpenAIClient
 from .anthropic_client import AnthropicClient
 from .google_client import GoogleClient
+from .vllm_client import VLLMClient
 
 
 def create_llm_client(
@@ -15,7 +16,7 @@ def create_llm_client(
     """Create an LLM client for the specified provider.
 
     Args:
-        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
+        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
         model: Model name/identifier
         base_url: Optional base URL for API endpoint
         **kwargs: Additional provider-specific arguments
@@ -40,4 +41,7 @@ def create_llm_client(
     if provider_lower == "google":
         return GoogleClient(model, base_url, **kwargs)
 
+    if provider_lower == "vllm":
+        return VLLMClient(model, base_url, **kwargs)
+
     raise ValueError(f"Unsupported LLM provider: {provider}")
diff --git a/tradingagents/llm_clients/vllm_client.py b/tradingagents/llm_clients/vllm_client.py
new file mode 100644
index 00000000..2779313b
--- /dev/null
+++ b/tradingagents/llm_clients/vllm_client.py
@@ -0,0 +1,32 @@
+import os
+from typing import Any, Optional
+
+from langchain_openai import ChatOpenAI
+
+from .base_client import BaseLLMClient
+
+
+class VLLMClient(BaseLLMClient):
+    """Client for vLLM models. Uses OpenAI-compatible API."""
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        super().__init__(model, base_url, **kwargs)
+
+    def get_llm(self) -> Any:
+        """Return configured ChatOpenAI instance for vLLM."""
+        llm_kwargs = {
+            "model": self.model,
+            "base_url": self.base_url or os.environ.get("VLLM_API_BASE", "http://localhost:8000/v1"),
+            "api_key": self.kwargs.get("api_key") or os.environ.get("VLLM_API_KEY") or "vllm",
+        }
+
+        # Add supported parameters
+        for key in ("temperature", "top_p", "max_tokens", "timeout", "max_retries", "callbacks"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        return ChatOpenAI(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Validate model name is provided."""
+        return bool(self.model and self.model.strip())