# Source: TradingAgents/tests/test_model_validation.py (58 lines, 2.1 KiB, Python)

import unittest
import warnings
from tradingagents.llm_clients.base_client import BaseLLMClient
from tradingagents.llm_clients.model_catalog import get_known_models
from tradingagents.llm_clients.validators import validate_model
class DummyLLMClient(BaseLLMClient):
    """Minimal ``BaseLLMClient`` subclass used by the model-validation tests.

    ``get_llm`` only calls ``warn_if_unknown_model()`` and returns a bare
    ``object()`` placeholder; ``validate_model`` delegates to the
    module-level ``validate_model`` function.
    """

    def __init__(self, provider: str, model: str) -> None:
        """Store *provider* and pass *model* to the base-class initializer."""
        # NOTE(review): provider is assigned before super().__init__() runs,
        # presumably so it is already set if the base initializer triggers
        # validation -- confirm against BaseLLMClient before reordering.
        self.provider = provider
        super().__init__(model)

    def get_llm(self) -> object:
        """Call ``warn_if_unknown_model()`` and return a placeholder object."""
        self.warn_if_unknown_model()
        return object()

    def validate_model(self) -> bool:
        """Return ``validate_model(self.provider, self.model)``."""
        # Inside a method body the bare name resolves to the imported
        # module-level function, not to this method, so this does not recurse.
        # assumes self.model is set by BaseLLMClient.__init__ -- TODO confirm
        return validate_model(self.provider, self.model)
class ModelValidationTests(unittest.TestCase):
    """Tests for ``validate_model`` and the unknown-model warning hook."""

    # Providers that the catalog test skips and that the custom-model test
    # expects to accept an arbitrary model name without emitting a warning.
    # Kept in one place so the two tests cannot drift apart.
    _FREE_FORM_PROVIDERS = ("openrouter", "ollama")

    def test_cli_catalog_models_are_all_validator_approved(self):
        """Every catalogued model of a non-skipped provider validates."""
        for provider, models in get_known_models().items():
            if provider in self._FREE_FORM_PROVIDERS:
                continue
            for model in models:
                with self.subTest(provider=provider, model=model):
                    self.assertTrue(validate_model(provider, model))

    def test_unknown_model_emits_warning_for_strict_provider(self):
        """An unrecognised "openai" model yields exactly one warning."""
        client = DummyLLMClient("openai", "not-a-real-openai-model")

        with warnings.catch_warnings(record=True) as caught:
            # "always" disables once-per-location suppression so the warning
            # is recorded regardless of earlier test activity.
            warnings.simplefilter("always")
            client.get_llm()

        # Pass the captured messages as msg so a wrong count is diagnosable.
        self.assertEqual(len(caught), 1, [str(w.message) for w in caught])
        message = str(caught[0].message)
        self.assertIn("not-a-real-openai-model", message)
        self.assertIn("openai", message)

    def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
        """A custom model name on openrouter/ollama emits no warning."""
        for provider in self._FREE_FORM_PROVIDERS:
            with self.subTest(provider=provider):
                # Built inside subTest so a constructor failure is reported
                # against the provider that caused it.
                client = DummyLLMClient(provider, "custom-model-name")

                with warnings.catch_warnings(record=True) as caught:
                    warnings.simplefilter("always")
                    client.get_llm()

                self.assertEqual(caught, [])

    def test_minimax_anthropic_compatible_models_are_known(self):
        """The listed MiniMax model names validate under "anthropic"."""
        for model in ("MiniMax-M2.7-highspeed", "MiniMax-M2.7"):
            with self.subTest(model=model):
                self.assertTrue(validate_model("anthropic", model))