TradingAgents/tests/test_llm_provider_support.py

90 lines
3.8 KiB
Python

import os
import unittest
from unittest.mock import patch
from tradingagents.llm_clients.factory import create_llm_client
from tradingagents.llm_clients.openai_client import OpenAIClient
from tradingagents.llm_clients.validators import validate_model
class LLMProviderSupportTests(unittest.TestCase):
    """Exercise DeepSeek and Kimi support in the client factory, client and validator.

    The ``get_llm`` tests replace ``NormalizedChatOpenAI`` with a mock and
    inspect the keyword arguments it was constructed with.
    """

    @staticmethod
    def _constructor_kwargs(mock_chat_openai):
        """Return the keyword arguments of the most recent mocked constructor call.

        ``Mock.call_args`` is ``None`` until the mock has been called, so the
        call is asserted first; otherwise a missing call would surface as an
        ``AttributeError`` on ``None`` rather than as a test failure.
        """
        mock_chat_openai.assert_called()
        return mock_chat_openai.call_args.kwargs

    def test_factory_supports_deepseek_and_kimi(self):
        """The factory returns an ``OpenAIClient`` for both providers."""
        deepseek_client = create_llm_client("deepseek", "deepseek-chat")
        kimi_client = create_llm_client("kimi", "kimi-latest")

        self.assertIsInstance(deepseek_client, OpenAIClient)
        self.assertIsInstance(kimi_client, OpenAIClient)

    @patch("tradingagents.llm_clients.openai_client.NormalizedChatOpenAI")
    def test_deepseek_uses_expected_base_url_and_key(self, mock_chat_openai):
        """Without an explicit base URL, DeepSeek gets its default URL and env key."""
        env = {"DEEPSEEK_API_KEY": "deepseek-test-key"}
        with patch.dict(os.environ, env, clear=False):
            client = OpenAIClient("deepseek-chat", provider="deepseek")
            # Called while the environment is patched: the API key is expected
            # to be resolved when the LLM is built.
            client.get_llm()

        kwargs = self._constructor_kwargs(mock_chat_openai)
        self.assertEqual(kwargs["base_url"], "https://api.deepseek.com/v1")
        self.assertEqual(kwargs["api_key"], "deepseek-test-key")
        self.assertEqual(kwargs["model"], "deepseek-chat")

    @patch("tradingagents.llm_clients.openai_client.NormalizedChatOpenAI")
    def test_deepseek_supports_custom_base_url(self, mock_chat_openai):
        """A caller-supplied ``base_url`` is forwarded instead of the default."""
        env = {"DEEPSEEK_API_KEY": "deepseek-test-key"}
        with patch.dict(os.environ, env, clear=False):
            client = OpenAIClient(
                "deepseek-chat",
                provider="deepseek",
                base_url="https://proxy.example.com/v1",
            )
            client.get_llm()

        kwargs = self._constructor_kwargs(mock_chat_openai)
        self.assertEqual(kwargs["base_url"], "https://proxy.example.com/v1")
        self.assertEqual(kwargs["api_key"], "deepseek-test-key")
        self.assertEqual(kwargs["model"], "deepseek-chat")

    def test_deepseek_missing_key_raises_clear_error(self):
        """An empty ``DEEPSEEK_API_KEY`` makes ``get_llm`` raise a ``ValueError``."""
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": ""}, clear=False):
            client = OpenAIClient("deepseek-chat", provider="deepseek")
            with self.assertRaisesRegex(
                ValueError, "Missing API key for provider 'deepseek'"
            ):
                client.get_llm()

    @patch("tradingagents.llm_clients.openai_client.NormalizedChatOpenAI")
    def test_kimi_prefers_kimi_api_key(self, mock_chat_openai):
        """With both keys set, ``KIMI_API_KEY`` wins over ``MOONSHOT_API_KEY``."""
        env = {
            "KIMI_API_KEY": "kimi-test-key",
            "MOONSHOT_API_KEY": "moonshot-test-key",
        }
        with patch.dict(os.environ, env, clear=False):
            client = OpenAIClient("kimi-latest", provider="kimi")
            client.get_llm()

        kwargs = self._constructor_kwargs(mock_chat_openai)
        self.assertEqual(kwargs["base_url"], "https://api.moonshot.cn/v1")
        self.assertEqual(kwargs["api_key"], "kimi-test-key")
        self.assertEqual(kwargs["model"], "kimi-latest")

    @patch("tradingagents.llm_clients.openai_client.NormalizedChatOpenAI")
    def test_kimi_falls_back_to_moonshot_key(self, mock_chat_openai):
        """With ``KIMI_API_KEY`` empty, ``MOONSHOT_API_KEY`` is used instead."""
        # KIMI_API_KEY is set to "" (not removed) so the empty-value case is
        # what gets exercised.
        env = {"KIMI_API_KEY": "", "MOONSHOT_API_KEY": "moonshot-test-key"}
        with patch.dict(os.environ, env, clear=False):
            client = OpenAIClient("kimi-thinking-preview", provider="kimi")
            client.get_llm()

        kwargs = self._constructor_kwargs(mock_chat_openai)
        self.assertEqual(kwargs["base_url"], "https://api.moonshot.cn/v1")
        self.assertEqual(kwargs["api_key"], "moonshot-test-key")
        self.assertEqual(kwargs["model"], "kimi-thinking-preview")

    def test_validator_allows_fast_moving_compatibility_providers(self):
        """``validate_model`` accepts an arbitrary model name for both providers."""
        for provider in ("deepseek", "kimi"):
            # subTest reports each provider separately and keeps checking the
            # second one even if the first fails.
            with self.subTest(provider=provider):
                self.assertTrue(validate_model(provider, "any-model-name"))
# Run the suite when this file is executed directly rather than imported
# by a test runner.
if __name__ == "__main__":
    unittest.main()