58 lines
2.1 KiB
Python
58 lines
2.1 KiB
Python
import unittest
|
|
import warnings
|
|
|
|
from tradingagents.llm_clients.base_client import BaseLLMClient
|
|
from tradingagents.llm_clients.model_catalog import get_known_models
|
|
from tradingagents.llm_clients.validators import validate_model
|
|
|
|
|
|
class DummyLLMClient(BaseLLMClient):
    """Minimal ``BaseLLMClient`` stand-in used by the validation tests.

    It never builds a real model client: ``get_llm`` only runs the
    inherited ``warn_if_unknown_model`` hook and hands back a placeholder.
    """

    def __init__(self, provider: str, model: str):
        # Keep the original ordering: ``provider`` is set before the base
        # initializer runs, in case the base class reads it (unverified here).
        self.provider = provider
        super().__init__(model)

    def get_llm(self):
        """Trigger the inherited unknown-model check and return a dummy object."""
        self.warn_if_unknown_model()
        placeholder = object()
        return placeholder

    def validate_model(self) -> bool:
        """Return the module-level validator's verdict for this provider/model."""
        provider, model = self.provider, self.model
        # Inside the method body this name resolves to the imported
        # module-level ``validate_model`` function, not to this method.
        return validate_model(provider, model)
|
|
|
|
|
|
class ModelValidationTests(unittest.TestCase):
    """Cross-checks the model catalog, the validator and the client warning hook."""

    def test_cli_catalog_models_are_all_validator_approved(self):
        """Every catalog model of a strict provider must pass ``validate_model``."""
        for provider, models in get_known_models().items():
            # ollama and openrouter accept custom model names (see the
            # no-warning test in this class), so they are not checked here.
            if provider in ("ollama", "openrouter"):
                continue
            for model in models:
                with self.subTest(provider=provider, model=model):
                    self.assertTrue(validate_model(provider, model))

    def test_unknown_model_emits_warning_for_strict_provider(self):
        """An unrecognised OpenAI model yields exactly one warning naming it."""
        client = DummyLLMClient("openai", "not-a-real-openai-model")

        with warnings.catch_warnings(record=True) as caught:
            # "always" bypasses the default once-per-location filtering so the
            # warning is recorded even if it was already triggered earlier.
            warnings.simplefilter("always")
            client.get_llm()

        messages = [str(w.message) for w in caught]
        # Pass the recorded messages as the failure message so a wrong count
        # shows what was actually emitted.
        self.assertEqual(len(messages), 1, messages)
        self.assertIn("not-a-real-openai-model", messages[0])
        self.assertIn("openai", messages[0])

    def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
        """Permissive providers must not warn about an uncatalogued model name."""
        for provider in ("openrouter", "ollama"):
            with self.subTest(provider=provider):
                # Constructed inside the subtest so a constructor failure is
                # attributed to this provider and the other one is still checked.
                client = DummyLLMClient(provider, "custom-model-name")
                with warnings.catch_warnings(record=True) as caught:
                    warnings.simplefilter("always")
                    client.get_llm()
                # Compare message text rather than WarningMessage objects so a
                # failure diff is readable.
                self.assertEqual([str(w.message) for w in caught], [])

    def test_minimax_anthropic_compatible_models_are_known(self):
        """MiniMax M2.7 model names validate under the ``anthropic`` provider."""
        for model in ("MiniMax-M2.7-highspeed", "MiniMax-M2.7"):
            with self.subTest(model=model):
                self.assertTrue(validate_model("anthropic", model))
|