TradingAgents/tests/test_model_validation.py

66 lines
2.3 KiB
Python

import unittest
import warnings
from tradingagents.llm_clients.base_client import BaseLLMClient
from tradingagents.llm_clients.model_catalog import get_known_models, get_model_options
from tradingagents.llm_clients.validators import validate_model
class DummyLLMClient(BaseLLMClient):
def __init__(self, provider: str, model: str):
self.provider = provider
super().__init__(model)
def get_llm(self):
self.warn_if_unknown_model()
return object()
def validate_model(self) -> bool:
return validate_model(self.provider, self.model)
class ModelValidationTests(unittest.TestCase):
def test_local_llamacpp_models_are_exposed_in_cli_catalog(self):
quick_models = [value for _, value in get_model_options("ollama", "quick")]
deep_models = [value for _, value in get_model_options("ollama", "deep")]
for model in (
"Qwen3.5-27B",
"Qwen3.5-35B-3A",
"Qwen3.5-35B-A3B",
"Qwen3.5-122B",
):
with self.subTest(model=model):
self.assertIn(model, quick_models + deep_models)
def test_cli_catalog_models_are_all_validator_approved(self):
for provider, models in get_known_models().items():
if provider in ("ollama", "openrouter"):
continue
for model in models:
with self.subTest(provider=provider, model=model):
self.assertTrue(validate_model(provider, model))
def test_unknown_model_emits_warning_for_strict_provider(self):
client = DummyLLMClient("openai", "not-a-real-openai-model")
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
client.get_llm()
self.assertEqual(len(caught), 1)
self.assertIn("not-a-real-openai-model", str(caught[0].message))
self.assertIn("openai", str(caught[0].message))
def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
for provider in ("openrouter", "ollama"):
client = DummyLLMClient(provider, "custom-model-name")
with self.subTest(provider=provider):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
client.get_llm()
self.assertEqual(caught, [])