83 lines
3.1 KiB
Python
83 lines
3.1 KiB
Python
"""Unit tests for LLM client factory."""
|
|
|
|
|
|
import pytest
|
|
|
|
from tradingagents.llm_clients.anthropic_client import AnthropicClient
|
|
from tradingagents.llm_clients.factory import create_llm_client
|
|
from tradingagents.llm_clients.google_client import GoogleClient
|
|
from tradingagents.llm_clients.openai_client import OpenAIClient
|
|
|
|
|
|
class TestCreateLLMClient:
    """Tests for the ``create_llm_client`` factory function.

    Each test asks the factory for a client by provider name and model name,
    then checks the concrete client class returned and the attributes it
    exposes (``model``, ``provider``, ``base_url``, ``kwargs``).  Providers
    served through an OpenAI-compatible API (xAI, Ollama, OpenRouter) are
    expected to come back as ``OpenAIClient`` instances tagged with their own
    provider name.
    """

    @pytest.mark.unit
    def test_create_openai_client(self):
        """Test creating an OpenAI client."""
        client = create_llm_client("openai", "gpt-4")

        assert isinstance(client, OpenAIClient)
        assert client.model == "gpt-4"
        assert client.provider == "openai"

    @pytest.mark.unit
    def test_create_openai_client_case_insensitive(self):
        """Test that provider names are case insensitive.

        The mixed-case name ``"OpenAI"`` must resolve to the same client class
        as ``"openai"``, and the stored provider is the lower-cased form.
        """
        client = create_llm_client("OpenAI", "gpt-4o")

        assert isinstance(client, OpenAIClient)
        # The model name must be passed through unchanged regardless of how
        # the provider name was spelled.
        assert client.model == "gpt-4o"
        assert client.provider == "openai"

    @pytest.mark.unit
    def test_create_anthropic_client(self):
        """Test creating an Anthropic client."""
        client = create_llm_client("anthropic", "claude-3-opus")

        assert isinstance(client, AnthropicClient)
        assert client.model == "claude-3-opus"

    @pytest.mark.unit
    def test_create_google_client(self):
        """Test creating a Google client."""
        client = create_llm_client("google", "gemini-pro")

        assert isinstance(client, GoogleClient)
        assert client.model == "gemini-pro"

    @pytest.mark.unit
    def test_create_xai_client(self):
        """Test creating an xAI client (uses OpenAI-compatible API)."""
        client = create_llm_client("xai", "grok-beta")

        assert isinstance(client, OpenAIClient)
        assert client.model == "grok-beta"
        assert client.provider == "xai"

    @pytest.mark.unit
    def test_create_ollama_client(self):
        """Test creating an Ollama client (served via ``OpenAIClient``)."""
        client = create_llm_client("ollama", "llama2")

        assert isinstance(client, OpenAIClient)
        assert client.model == "llama2"
        assert client.provider == "ollama"

    @pytest.mark.unit
    def test_create_openrouter_client(self):
        """Test creating an OpenRouter client (served via ``OpenAIClient``)."""
        client = create_llm_client("openrouter", "gpt-4")

        assert isinstance(client, OpenAIClient)
        assert client.model == "gpt-4"
        assert client.provider == "openrouter"

    @pytest.mark.unit
    def test_unsupported_provider_raises(self):
        """Test that an unknown provider name raises ``ValueError``.

        ``match`` is a regex searched against the exception message, so the
        message only needs to contain "Unsupported LLM provider".
        """
        with pytest.raises(ValueError, match="Unsupported LLM provider"):
            create_llm_client("unknown_provider", "model-name")

    @pytest.mark.unit
    def test_create_client_with_base_url(self):
        """Test that a custom ``base_url`` is stored on the created client."""
        client = create_llm_client(
            "openai", "gpt-4", base_url="https://custom.api.com/v1"
        )

        # Passing base_url must not change which client class is selected.
        assert isinstance(client, OpenAIClient)
        assert client.model == "gpt-4"
        assert client.base_url == "https://custom.api.com/v1"

    @pytest.mark.unit
    def test_create_client_with_kwargs(self):
        """Test that extra keyword arguments are kept in ``client.kwargs``."""
        client = create_llm_client("openai", "gpt-4", timeout=30, max_retries=5)

        assert isinstance(client, OpenAIClient)
        # .get() yields None for a missing key, so a dropped kwarg shows up
        # as a clean assertion failure rather than a KeyError.
        assert client.kwargs.get("timeout") == 30
        assert client.kwargs.get("max_retries") == 5
|