sync model validation with cli catalog
This commit is contained in:
parent
f362a160c3
commit
8793336dad
81
cli/utils.py
81
cli/utils.py
|
|
@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Dict
|
|||
from rich.console import Console
|
||||
|
||||
from cli.models import AnalystType
|
||||
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -129,48 +130,11 @@ def select_research_depth() -> int:
|
|||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
|
||||
# Define shallow thinking llm engine options with their corresponding model names
|
||||
# Ordering: medium → light → heavy (balanced first for quick tasks)
|
||||
# Within same tier, newer models first
|
||||
SHALLOW_AGENT_OPTIONS = {
|
||||
"openai": [
|
||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
||||
("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"),
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
|
||||
],
|
||||
"anthropic": [
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
"google": [
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
|
||||
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
||||
],
|
||||
"xai": [
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
],
|
||||
"openrouter": [
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
],
|
||||
"ollama": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
],
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Quick-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
|
||||
for display, value in get_model_options(provider, "quick")
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
@ -194,50 +158,11 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
def select_deep_thinking_agent(provider) -> str:
|
||||
"""Select deep thinking llm engine using an interactive selection."""
|
||||
|
||||
# Define deep thinking llm engine options with their corresponding model names
|
||||
# Ordering: heavy → medium → light (most capable first for deep tasks)
|
||||
# Within same tier, newer models first
|
||||
DEEP_AGENT_OPTIONS = {
|
||||
"openai": [
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
|
||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
||||
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
|
||||
],
|
||||
"anthropic": [
|
||||
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
|
||||
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
"google": [
|
||||
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
],
|
||||
"xai": [
|
||||
("Grok 4 - Flagship model", "grok-4-0709"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
],
|
||||
"openrouter": [
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
],
|
||||
"ollama": [
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
],
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Deep-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
|
||||
for display, value in get_model_options(provider, "deep")
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
import unittest
|
||||
import warnings
|
||||
|
||||
from tradingagents.llm_clients.base_client import BaseLLMClient
|
||||
from tradingagents.llm_clients.model_catalog import get_known_models
|
||||
from tradingagents.llm_clients.validators import validate_model
|
||||
|
||||
|
||||
class DummyLLMClient(BaseLLMClient):
|
||||
def __init__(self, provider: str, model: str):
|
||||
self.provider = provider
|
||||
super().__init__(model)
|
||||
|
||||
def get_llm(self):
|
||||
self.warn_if_unknown_model()
|
||||
return object()
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
return validate_model(self.provider, self.model)
|
||||
|
||||
|
||||
class ModelValidationTests(unittest.TestCase):
|
||||
def test_cli_catalog_models_are_all_validator_approved(self):
|
||||
for provider, models in get_known_models().items():
|
||||
if provider in ("ollama", "openrouter"):
|
||||
continue
|
||||
|
||||
for model in models:
|
||||
with self.subTest(provider=provider, model=model):
|
||||
self.assertTrue(validate_model(provider, model))
|
||||
|
||||
def test_unknown_model_emits_warning_for_strict_provider(self):
|
||||
client = DummyLLMClient("openai", "not-a-real-openai-model")
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
client.get_llm()
|
||||
|
||||
self.assertEqual(len(caught), 1)
|
||||
self.assertIn("not-a-real-openai-model", str(caught[0].message))
|
||||
self.assertIn("openai", str(caught[0].message))
|
||||
|
||||
def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
|
||||
for provider in ("openrouter", "ollama"):
|
||||
client = DummyLLMClient(provider, "custom-model-name")
|
||||
|
||||
with self.subTest(provider=provider):
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
client.get_llm()
|
||||
|
||||
self.assertEqual(caught, [])
|
||||
|
|
@ -14,6 +14,7 @@ class AnthropicClient(BaseLLMClient):
|
|||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatAnthropic instance."""
|
||||
self.warn_if_unknown_model()
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks", "http_client", "http_async_client"):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
|
|
@ -10,6 +11,27 @@ class BaseLLMClient(ABC):
|
|||
self.base_url = base_url
|
||||
self.kwargs = kwargs
|
||||
|
||||
def get_provider_name(self) -> str:
|
||||
"""Return the provider name used in warning messages."""
|
||||
provider = getattr(self, "provider", None)
|
||||
if provider:
|
||||
return str(provider)
|
||||
return self.__class__.__name__.removesuffix("Client").lower()
|
||||
|
||||
def warn_if_unknown_model(self) -> None:
|
||||
"""Warn when the model is outside the known list for the provider."""
|
||||
if self.validate_model():
|
||||
return
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
f"Model '{self.model}' is not in the known model list for "
|
||||
f"provider '{self.get_provider_name()}'. Continuing anyway."
|
||||
),
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_llm(self) -> Any:
|
||||
"""Return the configured LLM instance."""
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ class GoogleClient(BaseLLMClient):
|
|||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||
self.warn_if_unknown_model()
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "google_api_key", "callbacks", "http_client", "http_async_client"):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,106 @@
|
|||
"""Shared model catalog for CLI selections and validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
ModelOption = Tuple[str, str]
|
||||
ProviderModeOptions = Dict[str, List[ModelOption]]
|
||||
|
||||
|
||||
MODEL_OPTIONS: ProviderModeOptions = {
|
||||
"openai": {
|
||||
"quick": [
|
||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
||||
("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"),
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
|
||||
],
|
||||
"deep": [
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
|
||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
||||
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
|
||||
],
|
||||
},
|
||||
"anthropic": {
|
||||
"quick": [
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
"deep": [
|
||||
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
|
||||
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
},
|
||||
"google": {
|
||||
"quick": [
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
|
||||
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
||||
],
|
||||
"deep": [
|
||||
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
],
|
||||
},
|
||||
"xai": {
|
||||
"quick": [
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
],
|
||||
"deep": [
|
||||
("Grok 4 - Flagship model", "grok-4-0709"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
],
|
||||
},
|
||||
"openrouter": {
|
||||
"quick": [
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
],
|
||||
"deep": [
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
],
|
||||
},
|
||||
"ollama": {
|
||||
"quick": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
],
|
||||
"deep": [
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
|
||||
"""Return shared model options for a provider and selection mode."""
|
||||
return MODEL_OPTIONS[provider.lower()][mode]
|
||||
|
||||
|
||||
def get_known_models() -> Dict[str, List[str]]:
|
||||
"""Build known model names from the shared CLI catalog."""
|
||||
known_models: Dict[str, List[str]] = {}
|
||||
for provider, mode_options in MODEL_OPTIONS.items():
|
||||
model_names = {
|
||||
value
|
||||
for options in mode_options.values()
|
||||
for _, value in options
|
||||
}
|
||||
known_models[provider] = sorted(model_names)
|
||||
return known_models
|
||||
|
|
@ -41,6 +41,7 @@ class OpenAIClient(BaseLLMClient):
|
|||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatOpenAI instance."""
|
||||
self.warn_if_unknown_model()
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
if self.provider == "xai":
|
||||
|
|
|
|||
|
|
@ -1,53 +1,12 @@
|
|||
"""Model name validators for each provider.
|
||||
"""Model name validators for each provider."""
|
||||
|
||||
from .model_catalog import get_known_models
|
||||
|
||||
Only validates model names - does NOT enforce limits.
|
||||
Let LLM providers use their own defaults for unspecified params.
|
||||
"""
|
||||
|
||||
VALID_MODELS = {
|
||||
"openai": [
|
||||
# GPT-5 series
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
"gpt-5.2",
|
||||
"gpt-5.1",
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
# GPT-4.1 series
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
],
|
||||
"anthropic": [
|
||||
# Claude 4.6 series (latest)
|
||||
"claude-opus-4-6",
|
||||
"claude-sonnet-4-6",
|
||||
# Claude 4.5 series
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
],
|
||||
"google": [
|
||||
# Gemini 3.1 series (preview)
|
||||
"gemini-3.1-pro-preview",
|
||||
"gemini-3.1-flash-lite-preview",
|
||||
# Gemini 3 series (preview)
|
||||
"gemini-3-flash-preview",
|
||||
# Gemini 2.5 series
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
],
|
||||
"xai": [
|
||||
# Grok 4.1 series
|
||||
"grok-4-1-fast-reasoning",
|
||||
"grok-4-1-fast-non-reasoning",
|
||||
# Grok 4 series
|
||||
"grok-4-0709",
|
||||
"grok-4-fast-reasoning",
|
||||
"grok-4-fast-non-reasoning",
|
||||
],
|
||||
provider: models
|
||||
for provider, models in get_known_models().items()
|
||||
if provider not in ("ollama", "openrouter")
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue