diff --git a/tests/test_llm_clients.py b/tests/test_llm_clients.py index e7c0b9a0..6b2b8cc6 100644 --- a/tests/test_llm_clients.py +++ b/tests/test_llm_clients.py @@ -16,47 +16,49 @@ class LLMClientModelValidationTests(unittest.TestCase): str(warning_records[0].message), ) - def test_openai_client_warns_for_unknown_model(self): - model = "fake-openai-model" + def test_client_warns_for_unknown_model(self): + test_cases = [ + { + "client_class": OpenAIClient, + "patch_target": "tradingagents.llm_clients.openai_client.UnifiedChatOpenAI", + "provider_name": "OpenAI", + "model": "fake-openai-model", + "kwargs": {}, + }, + { + "client_class": AnthropicClient, + "patch_target": "tradingagents.llm_clients.anthropic_client.ChatAnthropic", + "provider_name": "Anthropic", + "model": "fake-claude-model", + "kwargs": {}, + }, + { + "client_class": GoogleClient, + "patch_target": "tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI", + "provider_name": "Google", + "model": "fake-gemini-model", + "kwargs": {}, + }, + { + "client_class": OpenAIClient, + "patch_target": "tradingagents.llm_clients.openai_client.UnifiedChatOpenAI", + "provider_name": "opena", + "model": "gpt-5-mini", + "kwargs": {"provider": "opena"}, + }, + ] - with patch( - "tradingagents.llm_clients.openai_client.UnifiedChatOpenAI", - side_effect=lambda **kwargs: kwargs, - ): - with warnings.catch_warnings(record=True) as warning_records: - warnings.simplefilter("always") - llm = OpenAIClient(model).get_llm() + for case in test_cases: + with self.subTest(provider=case["provider_name"]): + with patch(case["patch_target"], side_effect=lambda **kwargs: kwargs): + with warnings.catch_warnings(record=True) as warning_records: + warnings.simplefilter("always") + llm = case["client_class"](case["model"], **case["kwargs"]).get_llm() - self.assertEqual(llm["model"], model) - self.assert_single_unknown_model_warning(warning_records, "OpenAI", model) - - def test_anthropic_client_warns_for_unknown_model(self): - model = "fake-claude-model" - - with patch( - "tradingagents.llm_clients.anthropic_client.ChatAnthropic", - side_effect=lambda **kwargs: kwargs, - ): - with warnings.catch_warnings(record=True) as warning_records: - warnings.simplefilter("always") - llm = AnthropicClient(model).get_llm() - - self.assertEqual(llm["model"], model) - self.assert_single_unknown_model_warning(warning_records, "Anthropic", model) - - def test_google_client_warns_for_unknown_model(self): - model = "fake-gemini-model" - - with patch( - "tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI", - side_effect=lambda **kwargs: kwargs, - ): - with warnings.catch_warnings(record=True) as warning_records: - warnings.simplefilter("always") - llm = GoogleClient(model).get_llm() - - self.assertEqual(llm["model"], model) - self.assert_single_unknown_model_warning(warning_records, "Google", model) + self.assertEqual(llm["model"], case["model"]) + self.assert_single_unknown_model_warning( + warning_records, case["provider_name"], case["model"] + ) def test_openai_client_does_not_warn_for_known_model(self): with patch( diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py index 286bc6ac..dac2d268 100644 --- a/tradingagents/llm_clients/validators.py +++ b/tradingagents/llm_clients/validators.py @@ -61,6 +61,9 @@ PROVIDER_DISPLAY_NAMES = { "openrouter": "OpenRouter", } +PERMISSIVE_PROVIDERS = {"ollama", "openrouter"} +KNOWN_PROVIDERS = set(VALID_MODELS) | PERMISSIVE_PROVIDERS + def validate_model(provider: str, model: str) -> bool: """Check if model name is valid for the given provider. @@ -69,11 +72,11 @@ def validate_model(provider: str, model: str) -> bool: """ provider_lower = provider.lower() - if provider_lower in ("ollama", "openrouter"): + if provider_lower in PERMISSIVE_PROVIDERS: return True - if provider_lower not in VALID_MODELS: - return True + if provider_lower not in KNOWN_PROVIDERS: + return False return model in VALID_MODELS[provider_lower]