Tighten model provider validation

This commit is contained in:
hu.yelong 2026-03-19 15:49:29 +08:00
parent 897380dbe1
commit d252fcb5db
2 changed files with 47 additions and 42 deletions

View File

@ -16,47 +16,49 @@ class LLMClientModelValidationTests(unittest.TestCase):
str(warning_records[0].message), str(warning_records[0].message),
) )
def test_openai_client_warns_for_unknown_model(self): def test_client_warns_for_unknown_model(self):
model = "fake-openai-model" test_cases = [
{
"client_class": OpenAIClient,
"patch_target": "tradingagents.llm_clients.openai_client.UnifiedChatOpenAI",
"provider_name": "OpenAI",
"model": "fake-openai-model",
"kwargs": {},
},
{
"client_class": AnthropicClient,
"patch_target": "tradingagents.llm_clients.anthropic_client.ChatAnthropic",
"provider_name": "Anthropic",
"model": "fake-claude-model",
"kwargs": {},
},
{
"client_class": GoogleClient,
"patch_target": "tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI",
"provider_name": "Google",
"model": "fake-gemini-model",
"kwargs": {},
},
{
"client_class": OpenAIClient,
"patch_target": "tradingagents.llm_clients.openai_client.UnifiedChatOpenAI",
"provider_name": "opena",
"model": "gpt-5-mini",
"kwargs": {"provider": "opena"},
},
]
with patch( for case in test_cases:
"tradingagents.llm_clients.openai_client.UnifiedChatOpenAI", with self.subTest(provider=case["provider_name"]):
side_effect=lambda **kwargs: kwargs, with patch(case["patch_target"], side_effect=lambda **kwargs: kwargs):
): with warnings.catch_warnings(record=True) as warning_records:
with warnings.catch_warnings(record=True) as warning_records: warnings.simplefilter("always")
warnings.simplefilter("always") llm = case["client_class"](case["model"], **case["kwargs"]).get_llm()
llm = OpenAIClient(model).get_llm()
self.assertEqual(llm["model"], model) self.assertEqual(llm["model"], case["model"])
self.assert_single_unknown_model_warning(warning_records, "OpenAI", model) self.assert_single_unknown_model_warning(
warning_records, case["provider_name"], case["model"]
def test_anthropic_client_warns_for_unknown_model(self): )
model = "fake-claude-model"
with patch(
"tradingagents.llm_clients.anthropic_client.ChatAnthropic",
side_effect=lambda **kwargs: kwargs,
):
with warnings.catch_warnings(record=True) as warning_records:
warnings.simplefilter("always")
llm = AnthropicClient(model).get_llm()
self.assertEqual(llm["model"], model)
self.assert_single_unknown_model_warning(warning_records, "Anthropic", model)
def test_google_client_warns_for_unknown_model(self):
model = "fake-gemini-model"
with patch(
"tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI",
side_effect=lambda **kwargs: kwargs,
):
with warnings.catch_warnings(record=True) as warning_records:
warnings.simplefilter("always")
llm = GoogleClient(model).get_llm()
self.assertEqual(llm["model"], model)
self.assert_single_unknown_model_warning(warning_records, "Google", model)
def test_openai_client_does_not_warn_for_known_model(self): def test_openai_client_does_not_warn_for_known_model(self):
with patch( with patch(

View File

@ -61,6 +61,9 @@ PROVIDER_DISPLAY_NAMES = {
"openrouter": "OpenRouter", "openrouter": "OpenRouter",
} }
PERMISSIVE_PROVIDERS = {"ollama", "openrouter"}
KNOWN_PROVIDERS = set(VALID_MODELS) | PERMISSIVE_PROVIDERS
def validate_model(provider: str, model: str) -> bool: def validate_model(provider: str, model: str) -> bool:
"""Check if model name is valid for the given provider. """Check if model name is valid for the given provider.
@ -69,11 +72,11 @@ def validate_model(provider: str, model: str) -> bool:
""" """
provider_lower = provider.lower() provider_lower = provider.lower()
if provider_lower in ("ollama", "openrouter"): if provider_lower in PERMISSIVE_PROVIDERS:
return True return True
if provider_lower not in VALID_MODELS: if provider_lower not in KNOWN_PROVIDERS:
return True return False
return model in VALID_MODELS[provider_lower] return model in VALID_MODELS[provider_lower]