Tighten model provider validation
This commit is contained in:
parent
897380dbe1
commit
d252fcb5db
|
|
@ -16,47 +16,49 @@ class LLMClientModelValidationTests(unittest.TestCase):
|
||||||
str(warning_records[0].message),
|
str(warning_records[0].message),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_openai_client_warns_for_unknown_model(self):
|
def test_client_warns_for_unknown_model(self):
|
||||||
model = "fake-openai-model"
|
test_cases = [
|
||||||
|
{
|
||||||
|
"client_class": OpenAIClient,
|
||||||
|
"patch_target": "tradingagents.llm_clients.openai_client.UnifiedChatOpenAI",
|
||||||
|
"provider_name": "OpenAI",
|
||||||
|
"model": "fake-openai-model",
|
||||||
|
"kwargs": {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"client_class": AnthropicClient,
|
||||||
|
"patch_target": "tradingagents.llm_clients.anthropic_client.ChatAnthropic",
|
||||||
|
"provider_name": "Anthropic",
|
||||||
|
"model": "fake-claude-model",
|
||||||
|
"kwargs": {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"client_class": GoogleClient,
|
||||||
|
"patch_target": "tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI",
|
||||||
|
"provider_name": "Google",
|
||||||
|
"model": "fake-gemini-model",
|
||||||
|
"kwargs": {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"client_class": OpenAIClient,
|
||||||
|
"patch_target": "tradingagents.llm_clients.openai_client.UnifiedChatOpenAI",
|
||||||
|
"provider_name": "opena",
|
||||||
|
"model": "gpt-5-mini",
|
||||||
|
"kwargs": {"provider": "opena"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
with patch(
|
for case in test_cases:
|
||||||
"tradingagents.llm_clients.openai_client.UnifiedChatOpenAI",
|
with self.subTest(provider=case["provider_name"]):
|
||||||
side_effect=lambda **kwargs: kwargs,
|
with patch(case["patch_target"], side_effect=lambda **kwargs: kwargs):
|
||||||
):
|
|
||||||
with warnings.catch_warnings(record=True) as warning_records:
|
with warnings.catch_warnings(record=True) as warning_records:
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
llm = OpenAIClient(model).get_llm()
|
llm = case["client_class"](case["model"], **case["kwargs"]).get_llm()
|
||||||
|
|
||||||
self.assertEqual(llm["model"], model)
|
self.assertEqual(llm["model"], case["model"])
|
||||||
self.assert_single_unknown_model_warning(warning_records, "OpenAI", model)
|
self.assert_single_unknown_model_warning(
|
||||||
|
warning_records, case["provider_name"], case["model"]
|
||||||
def test_anthropic_client_warns_for_unknown_model(self):
|
)
|
||||||
model = "fake-claude-model"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"tradingagents.llm_clients.anthropic_client.ChatAnthropic",
|
|
||||||
side_effect=lambda **kwargs: kwargs,
|
|
||||||
):
|
|
||||||
with warnings.catch_warnings(record=True) as warning_records:
|
|
||||||
warnings.simplefilter("always")
|
|
||||||
llm = AnthropicClient(model).get_llm()
|
|
||||||
|
|
||||||
self.assertEqual(llm["model"], model)
|
|
||||||
self.assert_single_unknown_model_warning(warning_records, "Anthropic", model)
|
|
||||||
|
|
||||||
def test_google_client_warns_for_unknown_model(self):
|
|
||||||
model = "fake-gemini-model"
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI",
|
|
||||||
side_effect=lambda **kwargs: kwargs,
|
|
||||||
):
|
|
||||||
with warnings.catch_warnings(record=True) as warning_records:
|
|
||||||
warnings.simplefilter("always")
|
|
||||||
llm = GoogleClient(model).get_llm()
|
|
||||||
|
|
||||||
self.assertEqual(llm["model"], model)
|
|
||||||
self.assert_single_unknown_model_warning(warning_records, "Google", model)
|
|
||||||
|
|
||||||
def test_openai_client_does_not_warn_for_known_model(self):
|
def test_openai_client_does_not_warn_for_known_model(self):
|
||||||
with patch(
|
with patch(
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,9 @@ PROVIDER_DISPLAY_NAMES = {
|
||||||
"openrouter": "OpenRouter",
|
"openrouter": "OpenRouter",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PERMISSIVE_PROVIDERS = {"ollama", "openrouter"}
|
||||||
|
KNOWN_PROVIDERS = set(VALID_MODELS) | PERMISSIVE_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
def validate_model(provider: str, model: str) -> bool:
|
def validate_model(provider: str, model: str) -> bool:
|
||||||
"""Check if model name is valid for the given provider.
|
"""Check if model name is valid for the given provider.
|
||||||
|
|
@ -69,11 +72,11 @@ def validate_model(provider: str, model: str) -> bool:
|
||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in ("ollama", "openrouter"):
|
if provider_lower in PERMISSIVE_PROVIDERS:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if provider_lower not in VALID_MODELS:
|
if provider_lower not in KNOWN_PROVIDERS:
|
||||||
return True
|
return False
|
||||||
|
|
||||||
return model in VALID_MODELS[provider_lower]
|
return model in VALID_MODELS[provider_lower]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue