TradingAgents/tradingagents/llm_clients/validators.py

98 lines
2.5 KiB
Python

"""Model name validators for each provider.
Only validates model names - does NOT enforce limits.
Let LLM providers use their own defaults for unspecified params.
"""
import warnings
# Curated model names per provider, keyed by lowercase provider id.
# validate_model() does an exact, case-sensitive membership test against
# these lists.
VALID_MODELS = {
    "openai": [
        # GPT-5 family
        "gpt-5.4-pro",
        "gpt-5.4",
        "gpt-5.2",
        "gpt-5.1",
        "gpt-5",
        "gpt-5-mini",
        "gpt-5-nano",
        # GPT-4.1 family
        "gpt-4.1",
        "gpt-4.1-mini",
        "gpt-4.1-nano",
    ],
    "anthropic": [
        # Claude 4.6 family (latest)
        "claude-opus-4-6",
        "claude-sonnet-4-6",
        # Claude 4.5 family
        "claude-opus-4-5",
        "claude-sonnet-4-5",
        "claude-haiku-4-5",
    ],
    "google": [
        # Gemini 3.1 family (preview)
        "gemini-3.1-pro-preview",
        "gemini-3.1-flash-lite-preview",
        # Gemini 3 family (preview)
        "gemini-3-flash-preview",
        # Gemini 2.5 family
        "gemini-2.5-pro",
        "gemini-2.5-flash",
        "gemini-2.5-flash-lite",
    ],
    "xai": [
        # Grok 4.1 family
        "grok-4-1-fast-reasoning",
        "grok-4-1-fast-non-reasoning",
        # Grok 4 family
        "grok-4-0709",
        "grok-4-fast-reasoning",
        "grok-4-fast-non-reasoning",
    ],
}
# Human-readable provider names used in warning messages, keyed by the
# lowercase provider id.
PROVIDER_DISPLAY_NAMES = {
    "openai": "OpenAI",
    "anthropic": "Anthropic",
    "google": "Google",
    "xai": "xAI",
    "ollama": "Ollama",
    "openrouter": "OpenRouter",
}
# Providers for which validate_model() accepts any model name.
PERMISSIVE_PROVIDERS = {"ollama", "openrouter"}
# Every recognised provider id: those with a curated model list in
# VALID_MODELS plus the permissive ones.
KNOWN_PROVIDERS = PERMISSIVE_PROVIDERS.union(VALID_MODELS)
def validate_model(provider: str, model: str) -> bool:
    """Return whether ``model`` is an accepted name for ``provider``.

    The provider id is compared case-insensitively; the model name is
    compared exactly as given.

    Args:
        provider: Provider id, e.g. ``"openai"`` (any letter case).
        model: Model name to look up.

    Returns:
        ``True`` for every model when the provider is in
        ``PERMISSIVE_PROVIDERS`` (ollama, openrouter); ``False`` when the
        provider is not in ``KNOWN_PROVIDERS``; otherwise whether ``model``
        appears in the provider's ``VALID_MODELS`` list.
    """
    key = provider.lower()
    if key in PERMISSIVE_PROVIDERS:
        return True
    # Short-circuit keeps unknown providers from reaching the dict lookup.
    return key in KNOWN_PROVIDERS and model in VALID_MODELS[key]
def warn_if_unknown_model(provider: str, model: str) -> bool:
    """Emit a ``UserWarning`` when ``validate_model`` rejects the pair.

    Nothing is raised or blocked by this function itself; it only reports
    the unrecognised model and hands the validation result back.

    Args:
        provider: Provider id; its display name is looked up
            case-insensitively, falling back to the value as given.
        model: Model name to check.

    Returns:
        The result of ``validate_model(provider, model)``.
    """
    recognised = validate_model(provider, model)
    if recognised:
        return recognised

    display_name = PROVIDER_DISPLAY_NAMES.get(provider.lower(), provider)
    message = (
        f"Unknown {display_name} model '{model}'. "
        "Continuing without blocking because providers may add models before this list is updated."
    )
    # stacklevel=2 attributes the warning to the caller of this function.
    warnings.warn(message, UserWarning, stacklevel=2)
    return recognised