"""Unit tests for OpenAI client."""
"""Unit tests for OpenAI client."""
from unittest.mock import patch
import pytest
from tradingagents.llm_clients.openai_client import OpenAIClient, UnifiedChatOpenAI
class TestOpenAIClient:
    """Unit tests covering OpenAIClient construction and attribute storage."""

    @pytest.mark.unit
    def test_init_with_provider(self):
        """Constructor stores model and provider; base_url defaults to None."""
        c = OpenAIClient("gpt-4", provider="openai")
        assert c.model == "gpt-4"
        assert c.provider == "openai"
        assert c.base_url is None

    @pytest.mark.unit
    def test_init_with_base_url(self):
        """An explicit base_url keyword is kept verbatim on the client."""
        url = "https://custom.api.com/v1"
        c = OpenAIClient("gpt-4", base_url=url, provider="openai")
        assert c.base_url == url

    @pytest.mark.unit
    def test_provider_lowercase(self):
        """A mixed-case provider string is normalized to lowercase."""
        c = OpenAIClient("gpt-4", provider="OpenAI")
        assert c.provider == "openai"
|
|
|
|
|
|
class TestUnifiedChatOpenAI:
    """Unit tests for UnifiedChatOpenAI._is_reasoning_model classification."""

    @pytest.mark.unit
    def test_is_reasoning_model_o1(self):
        """o1-family model names are detected, regardless of case."""
        for model_name in ("o1-preview", "o1-mini", "O1-PRO"):
            assert UnifiedChatOpenAI._is_reasoning_model(model_name)

    @pytest.mark.unit
    def test_is_reasoning_model_o3(self):
        """o3-family model names are detected, regardless of case."""
        for model_name in ("o3-mini", "O3-MINI"):
            assert UnifiedChatOpenAI._is_reasoning_model(model_name)

    @pytest.mark.unit
    def test_is_reasoning_model_gpt5(self):
        """GPT-5-family model names are detected, regardless of case."""
        for model_name in ("gpt-5", "gpt-5.2", "GPT-5-MINI"):
            assert UnifiedChatOpenAI._is_reasoning_model(model_name)

    @pytest.mark.unit
    def test_is_not_reasoning_model(self):
        """Conventional chat models must not be classified as reasoning models."""
        for model_name in ("gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"):
            assert not UnifiedChatOpenAI._is_reasoning_model(model_name)
|
|
|
|
|
|
class TestOpenAIClientGetLLM:
    """Unit tests for OpenAIClient.get_llm() and constructor kwarg plumbing.

    Tests that invoke get_llm() patch the relevant API-key environment
    variable so they never depend on (or leak) the developer's real
    environment.
    """

    @pytest.mark.unit
    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_get_llm_openai(self):
        """get_llm() for the OpenAI provider returns an LLM bound to the model."""
        client = OpenAIClient("gpt-4", provider="openai")
        llm = client.get_llm()
        assert llm.model == "gpt-4"

    @pytest.mark.unit
    @patch.dict("os.environ", {"XAI_API_KEY": "test-xai-key"})
    def test_get_llm_xai_uses_correct_url(self):
        """xAI provider: base_url is not stored in kwargs (set inside get_llm).

        NOTE(review): this only asserts the absence of a kwargs entry; it does
        not verify the actual xAI URL — confirm get_llm() output separately.
        """
        client = OpenAIClient("grok-beta", provider="xai")
        assert client.kwargs.get("base_url") is None  # Not in kwargs, set in get_llm

    @pytest.mark.unit
    @patch.dict("os.environ", {"OPENROUTER_API_KEY": "test-or-key"})
    def test_get_llm_openrouter_uses_correct_url(self):
        """OpenRouter provider: base_url is not stored in kwargs (set inside get_llm).

        NOTE(review): same caveat as the xAI test — the real URL is unchecked.
        """
        client = OpenAIClient("gpt-4", provider="openrouter")
        assert client.kwargs.get("base_url") is None  # Not in kwargs, set in get_llm

    @pytest.mark.unit
    def test_get_llm_ollama_uses_correct_url(self):
        """Ollama provider name is stored (no API key required to construct)."""
        client = OpenAIClient("llama2", provider="ollama")
        assert client.provider == "ollama"

    @pytest.mark.unit
    def test_get_llm_with_timeout(self):
        """A timeout keyword is forwarded into the client's kwargs dict."""
        client = OpenAIClient("gpt-4", provider="openai", timeout=60)
        assert client.kwargs.get("timeout") == 60

    @pytest.mark.unit
    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_get_llm_with_max_retries(self):
        """max_retries is forwarded to the constructed LLM.

        Fix: patch OPENAI_API_KEY like test_get_llm_openai does — get_llm()
        is called here, so without the patch this test silently used (or
        required) the developer's real API key.
        """
        client = OpenAIClient("gpt-4", provider="openai", max_retries=3)
        llm = client.get_llm()
        assert llm.max_retries == 3
|