Merge c1ec37d830 into f362a160c3
commit d539bd5a2c

@@ -4,3 +4,7 @@ GOOGLE_API_KEY=
 ANTHROPIC_API_KEY=
 XAI_API_KEY=
 OPENROUTER_API_KEY=
+
+# vllm Configuration (optional)
+VLLM_API_BASE=http://localhost:8000/v1
+VLLM_API_KEY=

cli/main.py

@@ -556,6 +556,8 @@ def get_user_selections():
     # Step 7: Provider-specific thinking configuration
     thinking_level = None
     reasoning_effort = None
+    vllm_api_base = None
+    vllm_api_key = None
 
     provider_lower = selected_llm_provider.lower()
     if provider_lower == "google":

@@ -574,6 +576,14 @@ def get_user_selections():
             )
         )
         reasoning_effort = ask_openai_reasoning_effort()
+    elif provider_lower == "vllm":
+        console.print(
+            create_question_box(
+                "Step 7: vLLM Configuration",
+                "Configure the vLLM API connection"
+            )
+        )
+        vllm_api_base, vllm_api_key = ask_vllm_config()
 
     return {
         "ticker": selected_ticker,

@@ -586,6 +596,8 @@ def get_user_selections():
         "deep_thinker": selected_deep_thinker,
         "google_thinking_level": thinking_level,
         "openai_reasoning_effort": reasoning_effort,
+        "vllm_api_base": vllm_api_base,
+        "vllm_api_key": vllm_api_key,
     }

@@ -911,7 +923,11 @@ def run_analysis():
     # Provider-specific thinking configuration
     config["google_thinking_level"] = selections.get("google_thinking_level")
     config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
+    if selections["llm_provider"].lower() == "vllm":
+        if selections.get("vllm_api_base"):
+            config["backend_url"] = selections["vllm_api_base"]
+        config["vllm_api_base"] = selections.get("vllm_api_base")
+        config["vllm_api_key"] = selections.get("vllm_api_key")
 
     # Create stats callback handler for tracking LLM/tool calls
     stats_handler = StatsCallbackHandler()

cli/utils.py

@@ -164,6 +164,7 @@ def select_shallow_thinking_agent(provider) -> str:
             ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
             ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
         ],
+        "vllm": [("Qwen/Qwen3.5-2B", "Qwen/Qwen3.5-2B")]
     }
 
     choice = questionary.select(

@@ -231,6 +232,7 @@ def select_deep_thinking_agent(provider) -> str:
             ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
             ("Qwen3:latest (8B, local)", "qwen3:latest"),
         ],
+        "vllm": [("Qwen/Qwen3.5-2B", "Qwen/Qwen3.5-2B")]
     }
 
     choice = questionary.select(

@@ -265,6 +267,7 @@ def select_llm_provider() -> tuple[str, str]:
         ("xAI", "https://api.x.ai/v1"),
         ("Openrouter", "https://openrouter.ai/api/v1"),
         ("Ollama", "http://localhost:11434/v1"),
+        ("vLLM", "http://localhost:8000/v1")
     ]
 
     choice = questionary.select(

@@ -329,3 +332,21 @@ def ask_gemini_thinking_config() -> str | None:
             ("pointer", "fg:green noinherit"),
         ]),
     ).ask()
+
+
+def ask_vllm_config() -> tuple[str | None, str | None]:
+    """Ask for vLLM configuration."""
+    import os
+    default_base = os.environ.get("VLLM_API_BASE", "http://localhost:8000/v1")
+    api_base = questionary.text("Enter VLLM API URL:", default=default_base, style=questionary.Style(
+        [
+            ("text", "fg:green"),
+            ("highlighted", "noinherit"),
+        ])).ask()
+    default_api_key = os.environ.get("VLLM_API_KEY", "")
+    api_key = questionary.text("Enter VLLM API Key:", default=default_api_key, style=questionary.Style(
+        [
+            ("text", "fg:green"),
+            ("highlighted", "noinherit"),
+        ])).ask()
+    return api_base.strip() if api_base else None, api_key.strip() if api_key else None

@@ -15,6 +15,7 @@ DEFAULT_CONFIG = {
     # Provider-specific thinking configuration
     "google_thinking_level": None,  # "high", "minimal", etc.
     "openai_reasoning_effort": None,  # "medium", "high", "low"
+    "vllm_api_base": None,  # vllm API base URL (defaults to http://localhost:8000/v1)
     # Debate and discussion settings
     "max_debate_rounds": 1,
     "max_risk_discuss_rounds": 1,

@@ -78,16 +78,21 @@ class TradingAgentsGraph:
         if self.callbacks:
             llm_kwargs["callbacks"] = self.callbacks
 
+        provider = self.config["llm_provider"].lower()
+        if provider == "vllm":
+            base_url = self.config.get("vllm_api_base") or self.config["backend_url"]
+        else:
+            base_url = self.config["backend_url"]
         deep_client = create_llm_client(
             provider=self.config["llm_provider"],
             model=self.config["deep_think_llm"],
-            base_url=self.config.get("backend_url"),
+            base_url=base_url,
             **llm_kwargs,
         )
         quick_client = create_llm_client(
             provider=self.config["llm_provider"],
             model=self.config["quick_think_llm"],
-            base_url=self.config.get("backend_url"),
+            base_url=base_url,
             **llm_kwargs,
         )

@@ -148,6 +153,13 @@ class TradingAgentsGraph:
             if reasoning_effort:
                 kwargs["reasoning_effort"] = reasoning_effort
 
+        elif provider == "vllm":
+            # vllm-specific settings
+            api_key = self.config.get("vllm_api_key")
+            if api_key:
+                kwargs["api_key"] = api_key
+            # Add any vllm-specific parameters here if needed
+
         return kwargs
 
     def _create_tool_nodes(self) -> Dict[str, ToolNode]:
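
For reference, a minimal sketch of driving the new path through a config dict. The keys come from this diff; the import paths and the TradingAgentsGraph constructor signature are assumptions:

    # Hypothetical wiring; module paths below are assumptions, not part of this diff.
    from tradingagents.default_config import DEFAULT_CONFIG
    from tradingagents.graph.trading_graph import TradingAgentsGraph

    config = DEFAULT_CONFIG.copy()
    config["llm_provider"] = "vllm"
    config["deep_think_llm"] = "Qwen/Qwen3.5-2B"
    config["quick_think_llm"] = "Qwen/Qwen3.5-2B"
    config["vllm_api_base"] = "http://localhost:8000/v1"  # preferred over backend_url for vllm
    config["vllm_api_key"] = "token-abc123"               # optional; read with self.config.get("vllm_api_key")

    graph = TradingAgentsGraph(config=config)  # constructor signature is an assumption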

@@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
 from .openai_client import OpenAIClient
 from .anthropic_client import AnthropicClient
 from .google_client import GoogleClient
+from .vllm_client import VLLMClient
 
 
 def create_llm_client(

@@ -15,7 +16,7 @@ def create_llm_client(
     """Create an LLM client for the specified provider.
 
     Args:
-        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
+        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
         model: Model name/identifier
         base_url: Optional base URL for API endpoint
         **kwargs: Additional provider-specific arguments

@@ -46,4 +47,7 @@ def create_llm_client(
     if provider_lower == "google":
         return GoogleClient(model, base_url, **kwargs)
 
+    if provider_lower == "vllm":
+        return VLLMClient(model, base_url, **kwargs)
+
     raise ValueError(f"Unsupported LLM provider: {provider}")

@@ -54,11 +54,11 @@ VALID_MODELS = {
 def validate_model(provider: str, model: str) -> bool:
     """Check if model name is valid for the given provider.
 
-    For ollama, openrouter - any model is accepted.
+    For ollama, openrouter, vllm - any model is accepted.
     """
     provider_lower = provider.lower()
 
-    if provider_lower in ("ollama", "openrouter"):
+    if provider_lower in ("ollama", "openrouter", "vllm"):
         return True
 
     if provider_lower not in VALID_MODELS:
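
Since validate_model() lowercases the provider before the membership test, any model string passes for vLLM; a quick illustration:

    validate_model("vLLM", "Qwen/Qwen3.5-2B")   # True: "vllm" skips the VALID_MODELS lookup
    validate_model("vllm", "any/custom-model")  # True: vLLM model names are accepted as-is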

@@ -0,0 +1,32 @@
+import os
+from typing import Any, Optional
+
+from langchain_openai import ChatOpenAI
+
+from .base_client import BaseLLMClient
+
+
+class VLLMClient(BaseLLMClient):
+    """Client for vllm models. Uses the OpenAI-compatible API."""
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        super().__init__(model, base_url, **kwargs)
+
+    def get_llm(self) -> Any:
+        """Return a configured ChatOpenAI instance for vllm."""
+        llm_kwargs = {
+            "model": self.model,
+            "base_url": self.base_url or os.environ.get("VLLM_API_BASE", "http://localhost:8000/v1"),
+            "api_key": self.kwargs.get("api_key") or os.environ.get("VLLM_API_KEY") or "vllm",
+        }
+
+        # Add supported parameters
+        for key in ("temperature", "top_p", "max_tokens", "timeout", "max_retries", "callbacks"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        return ChatOpenAI(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Validate that a model name is provided."""
+        return bool(self.model and self.model.strip())
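
A hedged smoke test of the full factory path, assuming a vLLM server is listening on localhost:8000 and that the factory lives in a tradingagents.llm_clients package; the invoke() call is stock LangChain ChatOpenAI, not something this diff adds:

    # Hypothetical usage; the import path is an assumption, the rest follows this diff.
    from tradingagents.llm_clients import create_llm_client

    client = create_llm_client(
        provider="vLLM",                      # matched case-insensitively by the factory
        model="Qwen/Qwen3.5-2B",
        base_url="http://localhost:8000/v1",  # else falls back to VLLM_API_BASE, then localhost
        temperature=0.3,                      # forwarded: it is on VLLMClient's supported-parameter list
    )
    llm = client.get_llm()  # a ChatOpenAI pointed at the vLLM endpoint
    print(llm.invoke("Reply with one word.").content)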