From b3e94c51345e868e79f282ed22e92e3879f98ab6 Mon Sep 17 00:00:00 2001
From: xjx <493337577@qq.com>
Date: Mon, 9 Mar 2026 12:03:17 +0800
Subject: [PATCH] add vllm support

---
 cli/main.py | 61 ++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 46 insertions(+), 15 deletions(-)

diff --git a/cli/main.py b/cli/main.py
index 1aff4058..8f545b3d 100644
--- a/cli/main.py
+++ b/cli/main.py
@@ -544,26 +544,17 @@ def get_user_selections():
     )
     selected_llm_provider, backend_url = select_llm_provider()
 
-    # Step 6: Thinking agents
-    console.print(
-        create_question_box(
-            "Step 6: Thinking Agents", "Select your thinking agents for analysis"
-        )
-    )
-    selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
-    selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
-
-    # Step 7: Provider-specific thinking configuration
+    # Step 6: Provider-specific thinking configuration
     thinking_level = None
     reasoning_effort = None
     vllm_api_base = None
     vllm_api_key = None
-    
+
     provider_lower = selected_llm_provider.lower()
-    if provider_lower == "Google":
+    if provider_lower == "google":
         console.print(
             create_question_box(
-                "Step 7: Thinking Mode",
+                "Step 6: Thinking Mode",
                 "Configure Gemini thinking mode"
             )
         )
@@ -571,7 +562,7 @@ def get_user_selections():
     elif provider_lower == "openai":
         console.print(
             create_question_box(
-                "Step 7: Reasoning Effort",
+                "Step 6: Reasoning Effort",
                 "Configure OpenAI reasoning effort level"
             )
         )
@@ -579,11 +570,51 @@ def get_user_selections():
     elif provider_lower == "vllm":
         console.print(
             create_question_box(
-                "Step 7 : vLLM Configuration",
+                "Step 6: vLLM Configuration",
                 "Configure vLLM API configuration"
             )
         )
         vllm_api_base, vllm_api_key = ask_vllm_config()
+
+    # Step 7: Thinking agents
+    console.print(
+        create_question_box(
+            "Step 7: Thinking Agents", "Select your thinking agents for analysis"
+        )
+    )
+
+    # If vLLM is selected, fetch model name from /v1/models endpoint
+    if provider_lower == "vllm" and vllm_api_base:
+        try:
+            import requests
+            base_url = vllm_api_base.rstrip('/')
+            headers = {}
+            if vllm_api_key:
+                headers["Authorization"] = f"Bearer {vllm_api_key}"
+
+            response = requests.get(f"{base_url}/v1/models", headers=headers, timeout=5)
+            if response.status_code != 200:
+                raise Exception(f"Failed to fetch models from vLLM: status code {response.status_code}")
+
+            data = response.json()
+            if 'data' not in data or len(data['data']) == 0:
+                raise Exception("No models returned from vLLM API")
+
+            model_name = data['data'][0].get('id', 'unknown')
+            console.print(f"[green]✓ Fetched model from vLLM:[/green] {model_name}")
+            selected_shallow_thinker = model_name
+            selected_deep_thinker = model_name
+
+        except requests.exceptions.ConnectionError:
+            console.print(f"[red]Error: Cannot connect to vLLM at {vllm_api_base}[/red]")
+            raise
+        except Exception as e:
+            console.print(f"[red]Error: {e}[/red]")
+            raise
+    else:
+        # For other providers, use normal selection
+        selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
+        selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
 
     return {
         "ticker": selected_ticker,