Add vLLM support
This commit is contained in:
parent
34f4abf806
commit
c1ec37d830
18
cli/main.py
18
cli/main.py
|
|
@ -556,6 +556,8 @@ def get_user_selections():
|
|||
# Step 7: Provider-specific thinking configuration
|
||||
thinking_level = None
|
||||
reasoning_effort = None
|
||||
vllm_api_base = None
|
||||
vllm_api_key = None
|
||||
|
||||
provider_lower = selected_llm_provider.lower()
|
||||
if provider_lower == "google":
|
||||
|
|
@ -574,6 +576,14 @@ def get_user_selections():
|
|||
)
|
||||
)
|
||||
reasoning_effort = ask_openai_reasoning_effort()
|
||||
elif provider_lower == "vllm":
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 7 : vLLM Configuration",
|
||||
"Configure vLLM API configuration"
|
||||
)
|
||||
)
|
||||
vllm_api_base, vllm_api_key = ask_vllm_config()
|
||||
|
||||
return {
|
||||
"ticker": selected_ticker,
|
||||
|
|
@ -586,6 +596,8 @@ def get_user_selections():
|
|||
"deep_thinker": selected_deep_thinker,
|
||||
"google_thinking_level": thinking_level,
|
||||
"openai_reasoning_effort": reasoning_effort,
|
||||
"vllm_api_base": vllm_api_base,
|
||||
"vllm_api_key": vllm_api_key,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -911,7 +923,11 @@ def run_analysis():
|
|||
# Provider-specific thinking configuration
|
||||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
|
||||
if selections["llm_provider"] == "vllm":
|
||||
if selections.get("vllm_api_base"):
|
||||
config["backend_url"] = selections["vllm_api_base"]
|
||||
config["vllm_api_base"] = selections.get("vllm_api_base")
|
||||
config["vllm_api_key"] = selections.get("vllm_api_key")
|
||||
# Create stats callback handler for tracking LLM/tool calls
|
||||
stats_handler = StatsCallbackHandler()
|
||||
|
||||
|
|
|
|||
21
cli/utils.py
21
cli/utils.py
|
|
@ -160,6 +160,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
],
|
||||
"vllm": [("Qwen/Qwen3.5-2B", "Qwen/Qwen3.5-2B")]
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
@ -228,6 +229,7 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
],
|
||||
"vllm": [("Qwen/Qwen3.5-2B", "Qwen/Qwen3.5-2B")]
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
@ -262,6 +264,7 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
("xAI", "https://api.x.ai/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
("vLLM", "http://localhost:8000/v1")
|
||||
]
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
@ -326,3 +329,21 @@ def ask_gemini_thinking_config() -> str | None:
|
|||
("pointer", "fg:green noinherit"),
|
||||
]),
|
||||
).ask()
|
||||
|
||||
|
||||
def ask_vllm_config() -> tuple[str | None, str | None]:
    """Prompt the user for vLLM server connection settings.

    Defaults are seeded from the ``VLLM_API_BASE`` and ``VLLM_API_KEY``
    environment variables when present.

    Returns:
        A ``(api_base, api_key)`` tuple. Either element is ``None`` when the
        user submits an empty answer or cancels the prompt; otherwise the
        value is returned stripped of surrounding whitespace.
    """
    import os

    # Default must match the vLLM entry in select_llm_provider()
    # ("http://localhost:8000/v1"); a local vLLM server speaks plain HTTP,
    # so the previous "https://" default could never connect out of the box.
    default_base = os.environ.get("VLLM_API_BASE", "http://localhost:8000/v1")

    # One shared prompt style instead of two identical inline literals.
    prompt_style = questionary.Style(
        [
            ("text", "fg:green"),
            ("highlighted", "noinherit"),
        ]
    )

    api_base = questionary.text(
        "Enter VLLM API URL:", default=default_base, style=prompt_style
    ).ask()

    default_api_key = os.environ.get("VLLM_API_KEY", "")
    api_key = questionary.text(
        "Enter VLLM API Key:", default=default_api_key, style=prompt_style
    ).ask()

    return (
        api_base.strip() if api_base else None,
        api_key.strip() if api_key else None,
    )
|
||||
|
|
|
|||
|
|
@ -78,16 +78,21 @@ class TradingAgentsGraph:
|
|||
if self.callbacks:
|
||||
llm_kwargs["callbacks"] = self.callbacks
|
||||
|
||||
provider = self.config["llm_provider"].lower()
|
||||
if provider == "vllm":
|
||||
base_url = self.config.get("vllm_api_base") or self.config["backend_url"]
|
||||
else:
|
||||
base_url = self.config["backend_url"]
|
||||
deep_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["deep_think_llm"],
|
||||
base_url=self.config.get("backend_url"),
|
||||
base_url=base_url,
|
||||
**llm_kwargs,
|
||||
)
|
||||
quick_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["quick_think_llm"],
|
||||
base_url=self.config.get("backend_url"),
|
||||
base_url=base_url,
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
|
|
@ -147,9 +152,9 @@ class TradingAgentsGraph:
|
|||
|
||||
elif provider == "vllm":
|
||||
# vllm specific settings
|
||||
api_base = self.config.get("vllm_api_base")
|
||||
if api_base:
|
||||
kwargs["base_url"] = api_base
|
||||
api_key = self.config.get("vllm_api_key")
|
||||
if api_key:
|
||||
kwargs["api_key"] = api_key
|
||||
# Add any vllm-specific parameters here if needed
|
||||
|
||||
return kwargs
|
||||
|
|
|
|||
|
|
@ -69,11 +69,11 @@ VALID_MODELS = {
|
|||
def validate_model(provider: str, model: str) -> bool:
|
||||
"""Check if model name is valid for the given provider.
|
||||
|
||||
For ollama, openrouter - any model is accepted.
|
||||
For ollama, openrouter, vllm - any model is accepted.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("ollama", "openrouter"):
|
||||
if provider_lower in ("ollama", "openrouter", "vllm"):
|
||||
return True
|
||||
|
||||
if provider_lower not in VALID_MODELS:
|
||||
|
|
|
|||
Loading…
Reference in New Issue