Normalize model validation inputs
This commit is contained in:
parent
e17db7bd35
commit
83ad742dec
|
|
@ -50,3 +50,6 @@ class ModelValidationTests(unittest.TestCase):
|
||||||
client.get_llm()
|
client.get_llm()
|
||||||
|
|
||||||
self.assertEqual(caught, [])
|
self.assertEqual(caught, [])
|
||||||
|
|
||||||
|
def test_validator_accepts_known_model_with_surrounding_whitespace(self):
|
||||||
|
self.assertTrue(validate_model(" openai ", " gpt-5.4 "))
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,8 @@ def validate_model(provider: str, model: str) -> bool:
|
||||||
|
|
||||||
For ollama, openrouter - any model is accepted.
|
For ollama, openrouter - any model is accepted.
|
||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower().strip()
|
||||||
|
model_name = model.strip()
|
||||||
|
|
||||||
if provider_lower in ("ollama", "openrouter"):
|
if provider_lower in ("ollama", "openrouter"):
|
||||||
return True
|
return True
|
||||||
|
|
@ -23,4 +24,4 @@ def validate_model(provider: str, model: str) -> bool:
|
||||||
if provider_lower not in VALID_MODELS:
|
if provider_lower not in VALID_MODELS:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return model in VALID_MODELS[provider_lower]
|
return model_name in VALID_MODELS[provider_lower]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue