From 8793336dade0709b95233969147feafc00dc9ff4 Mon Sep 17 00:00:00 2001 From: CadeYu Date: Wed, 25 Mar 2026 21:23:02 +0800 Subject: [PATCH] sync model validation with cli catalog --- cli/utils.py | 81 +------------ tests/test_model_validation.py | 52 +++++++++ tradingagents/llm_clients/anthropic_client.py | 1 + tradingagents/llm_clients/base_client.py | 22 ++++ tradingagents/llm_clients/google_client.py | 1 + tradingagents/llm_clients/model_catalog.py | 106 ++++++++++++++++++ tradingagents/llm_clients/openai_client.py | 1 + tradingagents/llm_clients/validators.py | 53 +-------- 8 files changed, 192 insertions(+), 125 deletions(-) create mode 100644 tests/test_model_validation.py create mode 100644 tradingagents/llm_clients/model_catalog.py diff --git a/cli/utils.py b/cli/utils.py index 5a8ec16c..9869fb4d 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Dict from rich.console import Console from cli.models import AnalystType +from tradingagents.llm_clients.model_catalog import get_model_options console = Console() @@ -129,48 +130,11 @@ def select_research_depth() -> int: def select_shallow_thinking_agent(provider) -> str: """Select shallow thinking llm engine using an interactive selection.""" - # Define shallow thinking llm engine options with their corresponding model names - # Ordering: medium → light → heavy (balanced first for quick tasks) - # Within same tier, newer models first - SHALLOW_AGENT_OPTIONS = { - "openai": [ - ("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"), - ("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"), - ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), - ("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"), - ], - "anthropic": [ - ("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), - ("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"), - ("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"), - 
], - "google": [ - ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), - ("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"), - ("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"), - ("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"), - ], - "xai": [ - ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), - ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), - ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), - ], - "openrouter": [ - ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), - ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), - ], - "ollama": [ - ("Qwen3:latest (8B, local)", "qwen3:latest"), - ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), - ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), - ], - } - choice = questionary.select( "Select Your [Quick-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] + for display, value in get_model_options(provider, "quick") ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( @@ -194,50 +158,11 @@ def select_shallow_thinking_agent(provider) -> str: def select_deep_thinking_agent(provider) -> str: """Select deep thinking llm engine using an interactive selection.""" - # Define deep thinking llm engine options with their corresponding model names - # Ordering: heavy → medium → light (most capable first for deep tasks) - # Within same tier, newer models first - DEEP_AGENT_OPTIONS = { - "openai": [ - ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), - ("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"), - ("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"), - ("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"), - ], - 
"anthropic": [ - ("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"), - ("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"), - ("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), - ("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"), - ], - "google": [ - ("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"), - ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), - ("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"), - ("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"), - ], - "xai": [ - ("Grok 4 - Flagship model", "grok-4-0709"), - ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), - ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), - ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), - ], - "openrouter": [ - ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), - ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), - ], - "ollama": [ - ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), - ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), - ("Qwen3:latest (8B, local)", "qwen3:latest"), - ], - } - choice = questionary.select( "Select Your [Deep-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in DEEP_AGENT_OPTIONS[provider.lower()] + for display, value in get_model_options(provider, "deep") ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( diff --git a/tests/test_model_validation.py b/tests/test_model_validation.py new file mode 100644 index 00000000..50f26318 --- /dev/null +++ b/tests/test_model_validation.py @@ -0,0 +1,52 @@ +import unittest +import warnings + +from tradingagents.llm_clients.base_client import BaseLLMClient +from tradingagents.llm_clients.model_catalog import get_known_models +from 
tradingagents.llm_clients.validators import validate_model + + +class DummyLLMClient(BaseLLMClient): + def __init__(self, provider: str, model: str): + self.provider = provider + super().__init__(model) + + def get_llm(self): + self.warn_if_unknown_model() + return object() + + def validate_model(self) -> bool: + return validate_model(self.provider, self.model) + + +class ModelValidationTests(unittest.TestCase): + def test_cli_catalog_models_are_all_validator_approved(self): + for provider, models in get_known_models().items(): + if provider in ("ollama", "openrouter"): + continue + + for model in models: + with self.subTest(provider=provider, model=model): + self.assertTrue(validate_model(provider, model)) + + def test_unknown_model_emits_warning_for_strict_provider(self): + client = DummyLLMClient("openai", "not-a-real-openai-model") + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client.get_llm() + + self.assertEqual(len(caught), 1) + self.assertIn("not-a-real-openai-model", str(caught[0].message)) + self.assertIn("openai", str(caught[0].message)) + + def test_openrouter_and_ollama_accept_custom_models_without_warning(self): + for provider in ("openrouter", "ollama"): + client = DummyLLMClient(provider, "custom-model-name") + + with self.subTest(provider=provider): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client.get_llm() + + self.assertEqual(caught, []) diff --git a/tradingagents/llm_clients/anthropic_client.py b/tradingagents/llm_clients/anthropic_client.py index 8539c752..939c7488 100644 --- a/tradingagents/llm_clients/anthropic_client.py +++ b/tradingagents/llm_clients/anthropic_client.py @@ -14,6 +14,7 @@ class AnthropicClient(BaseLLMClient): def get_llm(self) -> Any: """Return configured ChatAnthropic instance.""" + self.warn_if_unknown_model() llm_kwargs = {"model": self.model} for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks", 
"http_client", "http_async_client"): diff --git a/tradingagents/llm_clients/base_client.py b/tradingagents/llm_clients/base_client.py index 43845575..81880856 100644 --- a/tradingagents/llm_clients/base_client.py +++ b/tradingagents/llm_clients/base_client.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Any, Optional +import warnings class BaseLLMClient(ABC): @@ -10,6 +11,27 @@ class BaseLLMClient(ABC): self.base_url = base_url self.kwargs = kwargs + def get_provider_name(self) -> str: + """Return the provider name used in warning messages.""" + provider = getattr(self, "provider", None) + if provider: + return str(provider) + return self.__class__.__name__.removesuffix("Client").lower() + + def warn_if_unknown_model(self) -> None: + """Warn when the model is outside the known list for the provider.""" + if self.validate_model(): + return + + warnings.warn( + ( + f"Model '{self.model}' is not in the known model list for " + f"provider '{self.get_provider_name()}'. Continuing anyway." 
+ ), + RuntimeWarning, + stacklevel=2, + ) + @abstractmethod def get_llm(self) -> Any: """Return the configured LLM instance.""" diff --git a/tradingagents/llm_clients/google_client.py b/tradingagents/llm_clients/google_client.py index 3dd85e3f..557e2640 100644 --- a/tradingagents/llm_clients/google_client.py +++ b/tradingagents/llm_clients/google_client.py @@ -36,6 +36,7 @@ class GoogleClient(BaseLLMClient): def get_llm(self) -> Any: """Return configured ChatGoogleGenerativeAI instance.""" + self.warn_if_unknown_model() llm_kwargs = {"model": self.model} for key in ("timeout", "max_retries", "google_api_key", "callbacks", "http_client", "http_async_client"): diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py new file mode 100644 index 00000000..58447a89 --- /dev/null +++ b/tradingagents/llm_clients/model_catalog.py @@ -0,0 +1,106 @@ +"""Shared model catalog for CLI selections and validation.""" + +from __future__ import annotations + +from typing import Dict, List, Tuple + +ModelOption = Tuple[str, str] +ProviderModeOptions = Dict[str, List[ModelOption]] + + +MODEL_OPTIONS: Dict[str, ProviderModeOptions] = { + "openai": { + "quick": [ + ("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"), + ("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"), + ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), + ("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"), + ], + "deep": [ + ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), + ("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"), + ("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"), + ("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"), + ], + }, + "anthropic": { + "quick": [ + ("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), + ("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"), + ("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"), + 
], + "deep": [ + ("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"), + ("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"), + ("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), + ("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"), + ], + }, + "google": { + "quick": [ + ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"), + ("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"), + ("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"), + ], + "deep": [ + ("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"), + ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"), + ("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"), + ], + }, + "xai": { + "quick": [ + ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), + ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ], + "deep": [ + ("Grok 4 - Flagship model", "grok-4-0709"), + ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), + ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ], + }, + "openrouter": { + "quick": [ + ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), + ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), + ], + "deep": [ + ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), + ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), + ], + }, + "ollama": { + "quick": [ + ("Qwen3:latest (8B, local)", "qwen3:latest"), + ("GPT-OSS:latest (20B, local)", 
"gpt-oss:latest"), + ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), + ], + "deep": [ + ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), + ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), + ("Qwen3:latest (8B, local)", "qwen3:latest"), + ], + }, +} + + +def get_model_options(provider: str, mode: str) -> List[ModelOption]: + """Return shared model options for a provider and selection mode.""" + return MODEL_OPTIONS[provider.lower()][mode] + + +def get_known_models() -> Dict[str, List[str]]: + """Build known model names from the shared CLI catalog.""" + known_models: Dict[str, List[str]] = {} + for provider, mode_options in MODEL_OPTIONS.items(): + model_names = { + value + for options in mode_options.values() + for _, value in options + } + known_models[provider] = sorted(model_names) + return known_models diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 4605c1f9..0629d894 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -41,6 +41,7 @@ class OpenAIClient(BaseLLMClient): def get_llm(self) -> Any: """Return configured ChatOpenAI instance.""" + self.warn_if_unknown_model() llm_kwargs = {"model": self.model} if self.provider == "xai": diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py index 1e2388b3..4e6d457b 100644 --- a/tradingagents/llm_clients/validators.py +++ b/tradingagents/llm_clients/validators.py @@ -1,53 +1,12 @@ -"""Model name validators for each provider. +"""Model name validators for each provider.""" + +from .model_catalog import get_known_models -Only validates model names - does NOT enforce limits. -Let LLM providers use their own defaults for unspecified params. 
-""" VALID_MODELS = { - "openai": [ - # GPT-5 series - "gpt-5.4-pro", - "gpt-5.4", - "gpt-5.2", - "gpt-5.1", - "gpt-5", - "gpt-5-mini", - "gpt-5-nano", - # GPT-4.1 series - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - ], - "anthropic": [ - # Claude 4.6 series (latest) - "claude-opus-4-6", - "claude-sonnet-4-6", - # Claude 4.5 series - "claude-opus-4-5", - "claude-sonnet-4-5", - "claude-haiku-4-5", - ], - "google": [ - # Gemini 3.1 series (preview) - "gemini-3.1-pro-preview", - "gemini-3.1-flash-lite-preview", - # Gemini 3 series (preview) - "gemini-3-flash-preview", - # Gemini 2.5 series - "gemini-2.5-pro", - "gemini-2.5-flash", - "gemini-2.5-flash-lite", - ], - "xai": [ - # Grok 4.1 series - "grok-4-1-fast-reasoning", - "grok-4-1-fast-non-reasoning", - # Grok 4 series - "grok-4-0709", - "grok-4-fast-reasoning", - "grok-4-fast-non-reasoning", - ], + provider: models + for provider, models in get_known_models().items() + if provider not in ("ollama", "openrouter") }