feat(cli): allow custom OpenRouter model IDs
This commit is contained in:
parent
f362a160c3
commit
4bcca59ee8
58
cli/utils.py
58
cli/utils.py
|
|
@ -1,5 +1,6 @@
|
||||||
|
from typing import Callable, List, Optional, Tuple, Dict
|
||||||
|
|
||||||
import questionary
|
import questionary
|
||||||
from typing import List, Optional, Tuple, Dict
|
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
|
|
@ -7,6 +8,8 @@ from cli.models import AnalystType
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
CUSTOM_OPENROUTER_MODEL = "__custom_openrouter_model__"
|
||||||
|
|
||||||
ANALYST_ORDER = [
|
ANALYST_ORDER = [
|
||||||
("Market Analyst", AnalystType.MARKET),
|
("Market Analyst", AnalystType.MARKET),
|
||||||
("Social Media Analyst", AnalystType.SOCIAL),
|
("Social Media Analyst", AnalystType.SOCIAL),
|
||||||
|
|
@ -68,6 +71,43 @@ def get_analysis_date() -> str:
|
||||||
return date.strip()
|
return date.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_custom_openrouter_model(model_role: str) -> Optional[str]:
|
||||||
|
"""Prompt for an OpenRouter model id when the built-in list is insufficient."""
|
||||||
|
return questionary.text(
|
||||||
|
f"Enter the OpenRouter model ID for the {model_role} model (e.g. minimax/minimax-m2.1):",
|
||||||
|
validate=lambda x: len(x.strip()) > 0
|
||||||
|
or "Please enter a valid OpenRouter model ID.",
|
||||||
|
style=questionary.Style(
|
||||||
|
[
|
||||||
|
("text", "fg:magenta"),
|
||||||
|
("highlighted", "noinherit"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model_choice(
|
||||||
|
provider: str,
|
||||||
|
choice: Optional[str],
|
||||||
|
model_role: str,
|
||||||
|
prompt_fn: Optional[Callable[[str], Optional[str]]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Resolve built-in and custom model selections into a concrete model id."""
|
||||||
|
if choice is None:
|
||||||
|
console.print(f"\n[red]No {model_role.lower()} llm engine selected. Exiting...[/red]")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if provider.lower() == "openrouter" and choice == CUSTOM_OPENROUTER_MODEL:
|
||||||
|
prompt_fn = prompt_fn or prompt_custom_openrouter_model
|
||||||
|
custom_model = prompt_fn(model_role)
|
||||||
|
if not custom_model or not custom_model.strip():
|
||||||
|
console.print("\n[red]No OpenRouter model ID provided. Exiting...[/red]")
|
||||||
|
exit(1)
|
||||||
|
return custom_model.strip()
|
||||||
|
|
||||||
|
return choice
|
||||||
|
|
||||||
|
|
||||||
def select_analysts() -> List[AnalystType]:
|
def select_analysts() -> List[AnalystType]:
|
||||||
"""Select analysts using an interactive checkbox."""
|
"""Select analysts using an interactive checkbox."""
|
||||||
choices = questionary.checkbox(
|
choices = questionary.checkbox(
|
||||||
|
|
@ -158,6 +198,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
"openrouter": [
|
"openrouter": [
|
||||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||||
|
("Custom OpenRouter model ID", CUSTOM_OPENROUTER_MODEL),
|
||||||
],
|
],
|
||||||
"ollama": [
|
"ollama": [
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||||
|
|
@ -182,13 +223,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
return resolve_model_choice(provider, choice, "Quick-Thinking", prompt_custom_openrouter_model)
|
||||||
console.print(
|
|
||||||
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
|
|
||||||
)
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
return choice
|
|
||||||
|
|
||||||
|
|
||||||
def select_deep_thinking_agent(provider) -> str:
|
def select_deep_thinking_agent(provider) -> str:
|
||||||
|
|
@ -225,6 +260,7 @@ def select_deep_thinking_agent(provider) -> str:
|
||||||
"openrouter": [
|
"openrouter": [
|
||||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||||
|
("Custom OpenRouter model ID", CUSTOM_OPENROUTER_MODEL),
|
||||||
],
|
],
|
||||||
"ollama": [
|
"ollama": [
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||||
|
|
@ -249,11 +285,7 @@ def select_deep_thinking_agent(provider) -> str:
|
||||||
),
|
),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
return resolve_model_choice(provider, choice, "Deep-Thinking", prompt_custom_openrouter_model)
|
||||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
return choice
|
|
||||||
|
|
||||||
def select_llm_provider() -> tuple[str, str]:
|
def select_llm_provider() -> tuple[str, str]:
|
||||||
"""Select the OpenAI api url using interactive selection."""
|
"""Select the OpenAI api url using interactive selection."""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from cli.utils import CUSTOM_OPENROUTER_MODEL, resolve_model_choice
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterModelSelectionTests(unittest.TestCase):
|
||||||
|
def test_builtin_model_is_returned_unchanged(self):
|
||||||
|
self.assertEqual(
|
||||||
|
resolve_model_choice("openrouter", "z-ai/glm-4.5-air:free", "Quick-Thinking"),
|
||||||
|
"z-ai/glm-4.5-air:free",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_custom_model_prompt_value_is_trimmed(self):
|
||||||
|
chosen = resolve_model_choice(
|
||||||
|
"openrouter",
|
||||||
|
CUSTOM_OPENROUTER_MODEL,
|
||||||
|
"Deep-Thinking",
|
||||||
|
prompt_fn=lambda _: " minimax/minimax-m2.1 ",
|
||||||
|
)
|
||||||
|
self.assertEqual(chosen, "minimax/minimax-m2.1")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue