diff --git a/tradingagents/llm_clients/anthropic_client.py b/tradingagents/llm_clients/anthropic_client.py index 8539c752..0be4c3ee 100644 --- a/tradingagents/llm_clients/anthropic_client.py +++ b/tradingagents/llm_clients/anthropic_client.py @@ -3,7 +3,7 @@ from typing import Any, Optional from langchain_anthropic import ChatAnthropic from .base_client import BaseLLMClient -from .validators import validate_model +from .validators import validate_model, warn_if_unknown_model class AnthropicClient(BaseLLMClient): @@ -14,6 +14,7 @@ class AnthropicClient(BaseLLMClient): def get_llm(self) -> Any: """Return configured ChatAnthropic instance.""" + warn_if_unknown_model("anthropic", self.model) llm_kwargs = {"model": self.model} for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks", "http_client", "http_async_client"): diff --git a/tradingagents/llm_clients/google_client.py b/tradingagents/llm_clients/google_client.py index 3dd85e3f..8ec5be64 100644 --- a/tradingagents/llm_clients/google_client.py +++ b/tradingagents/llm_clients/google_client.py @@ -3,7 +3,7 @@ from typing import Any, Optional from langchain_google_genai import ChatGoogleGenerativeAI from .base_client import BaseLLMClient -from .validators import validate_model +from .validators import validate_model, warn_if_unknown_model class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI): @@ -36,6 +36,7 @@ class GoogleClient(BaseLLMClient): def get_llm(self) -> Any: """Return configured ChatGoogleGenerativeAI instance.""" + warn_if_unknown_model("google", self.model) llm_kwargs = {"model": self.model} for key in ("timeout", "max_retries", "google_api_key", "callbacks", "http_client", "http_async_client"): diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 4605c1f9..b8572e1f 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -4,7 +4,7 @@ from typing import Any, Optional from langchain_openai import ChatOpenAI from .base_client import BaseLLMClient -from .validators import validate_model +from .validators import validate_model, warn_if_unknown_model class UnifiedChatOpenAI(ChatOpenAI): @@ -41,6 +41,7 @@ class OpenAIClient(BaseLLMClient): def get_llm(self) -> Any: """Return configured ChatOpenAI instance.""" + warn_if_unknown_model(self.provider, self.model) llm_kwargs = {"model": self.model} if self.provider == "xai": diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py index 1e2388b3..286bc6ac 100644 --- a/tradingagents/llm_clients/validators.py +++ b/tradingagents/llm_clients/validators.py @@ -4,6 +4,8 @@ Only validates model names - does NOT enforce limits. Let LLM providers use their own defaults for unspecified params. """ +import warnings + VALID_MODELS = { "openai": [ # GPT-5 series @@ -50,6 +52,15 @@ VALID_MODELS = { ], } +PROVIDER_DISPLAY_NAMES = { + "openai": "OpenAI", + "anthropic": "Anthropic", + "google": "Google", + "xai": "xAI", + "ollama": "Ollama", + "openrouter": "OpenRouter", +} + def validate_model(provider: str, model: str) -> bool: """Check if model name is valid for the given provider. @@ -65,3 +76,19 @@ def validate_model(provider: str, model: str) -> bool: return True return model in VALID_MODELS[provider_lower] + + +def warn_if_unknown_model(provider: str, model: str) -> bool: + """Warn for unknown models while allowing execution to continue.""" + is_valid = validate_model(provider, model) + if not is_valid: + provider_name = PROVIDER_DISPLAY_NAMES.get(provider.lower(), provider) + warnings.warn( + ( + f"Unknown {provider_name} model '{model}'. " + "Continuing without blocking because providers may add models before this list is updated." + ), + UserWarning, + stacklevel=2, + ) + return is_valid