feat(cli): allow custom OpenRouter model IDs
This commit is contained in:
parent
f362a160c3
commit
4bcca59ee8
58
cli/utils.py
58
cli/utils.py
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Callable, List, Optional, Tuple, Dict
|
||||
|
||||
import questionary
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
|
|
@ -7,6 +8,8 @@ from cli.models import AnalystType
|
|||
|
||||
console = Console()
|
||||
|
||||
CUSTOM_OPENROUTER_MODEL = "__custom_openrouter_model__"
|
||||
|
||||
ANALYST_ORDER = [
|
||||
("Market Analyst", AnalystType.MARKET),
|
||||
("Social Media Analyst", AnalystType.SOCIAL),
|
||||
|
|
@ -68,6 +71,43 @@ def get_analysis_date() -> str:
|
|||
return date.strip()
|
||||
|
||||
|
||||
def prompt_custom_openrouter_model(model_role: str) -> Optional[str]:
|
||||
"""Prompt for an OpenRouter model id when the built-in list is insufficient."""
|
||||
return questionary.text(
|
||||
f"Enter the OpenRouter model ID for the {model_role} model (e.g. minimax/minimax-m2.1):",
|
||||
validate=lambda x: len(x.strip()) > 0
|
||||
or "Please enter a valid OpenRouter model ID.",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("text", "fg:magenta"),
|
||||
("highlighted", "noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
|
||||
def resolve_model_choice(
|
||||
provider: str,
|
||||
choice: Optional[str],
|
||||
model_role: str,
|
||||
prompt_fn: Optional[Callable[[str], Optional[str]]] = None,
|
||||
) -> str:
|
||||
"""Resolve built-in and custom model selections into a concrete model id."""
|
||||
if choice is None:
|
||||
console.print(f"\n[red]No {model_role.lower()} llm engine selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
if provider.lower() == "openrouter" and choice == CUSTOM_OPENROUTER_MODEL:
|
||||
prompt_fn = prompt_fn or prompt_custom_openrouter_model
|
||||
custom_model = prompt_fn(model_role)
|
||||
if not custom_model or not custom_model.strip():
|
||||
console.print("\n[red]No OpenRouter model ID provided. Exiting...[/red]")
|
||||
exit(1)
|
||||
return custom_model.strip()
|
||||
|
||||
return choice
|
||||
|
||||
|
||||
def select_analysts() -> List[AnalystType]:
|
||||
"""Select analysts using an interactive checkbox."""
|
||||
choices = questionary.checkbox(
|
||||
|
|
@ -158,6 +198,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
"openrouter": [
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
("Custom OpenRouter model ID", CUSTOM_OPENROUTER_MODEL),
|
||||
],
|
||||
"ollama": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
|
|
@ -182,13 +223,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
),
|
||||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print(
|
||||
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
return choice
|
||||
return resolve_model_choice(provider, choice, "Quick-Thinking", prompt_custom_openrouter_model)
|
||||
|
||||
|
||||
def select_deep_thinking_agent(provider) -> str:
|
||||
|
|
@ -225,6 +260,7 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
"openrouter": [
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
("Custom OpenRouter model ID", CUSTOM_OPENROUTER_MODEL),
|
||||
],
|
||||
"ollama": [
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
|
|
@ -249,11 +285,7 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
),
|
||||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
return choice
|
||||
return resolve_model_choice(provider, choice, "Deep-Thinking", prompt_custom_openrouter_model)
|
||||
|
||||
def select_llm_provider() -> tuple[str, str]:
|
||||
"""Select the OpenAI api url using interactive selection."""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
import unittest
|
||||
|
||||
from cli.utils import CUSTOM_OPENROUTER_MODEL, resolve_model_choice
|
||||
|
||||
|
||||
class OpenRouterModelSelectionTests(unittest.TestCase):
|
||||
def test_builtin_model_is_returned_unchanged(self):
|
||||
self.assertEqual(
|
||||
resolve_model_choice("openrouter", "z-ai/glm-4.5-air:free", "Quick-Thinking"),
|
||||
"z-ai/glm-4.5-air:free",
|
||||
)
|
||||
|
||||
def test_custom_model_prompt_value_is_trimmed(self):
|
||||
chosen = resolve_model_choice(
|
||||
"openrouter",
|
||||
CUSTOM_OPENROUTER_MODEL,
|
||||
"Deep-Thinking",
|
||||
prompt_fn=lambda _: " minimax/minimax-m2.1 ",
|
||||
)
|
||||
self.assertEqual(chosen, "minimax/minimax-m2.1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue