# tradingagents/llm_clients/openai_client.py
"""LLM client for OpenAI and OpenAI-compatible chat providers."""
import os
from typing import Any, Optional
from langchain_openai import ChatOpenAI
from .base_client import BaseLLMClient, normalize_content
from .validators import validate_model
class NormalizedChatOpenAI(ChatOpenAI):
    """ChatOpenAI with normalized content output.

    The Responses API returns content as a list of typed blocks
    (reasoning, text, etc.). This normalizes to string for consistent
    downstream handling, on both the sync and async invocation paths.
    """

    def invoke(self, input, config=None, **kwargs):
        """Invoke the model and normalize the response content."""
        return normalize_content(super().invoke(input, config, **kwargs))

    async def ainvoke(self, input, config=None, **kwargs):
        # Without this override, async callers would receive the raw
        # typed-block content while sync callers got normalized strings.
        """Async invoke with the same content normalization as invoke()."""
        return normalize_content(await super().ainvoke(input, config, **kwargs))
# Kwargs forwarded from user config to ChatOpenAI.
# Only these keys are copied through; anything else in user kwargs is ignored
# by get_llm().
_PASSTHROUGH_KWARGS = (
    "timeout", "max_retries", "reasoning_effort",
    "api_key", "callbacks", "http_client", "http_async_client",
)

# Provider base URLs and API key env vars.
# Maps provider name -> (default base URL, tuple of candidate API-key
# environment variables, checked in order). An empty env-var tuple means the
# provider needs no real key (e.g. a local Ollama server).
_PROVIDER_CONFIG = {
    "xai": ("https://api.x.ai/v1", ("XAI_API_KEY",)),
    "openrouter": ("https://openrouter.ai/api/v1", ("OPENROUTER_API_KEY",)),
    "deepseek": ("https://api.deepseek.com/v1", ("DEEPSEEK_API_KEY",)),
    "kimi": ("https://api.moonshot.cn/v1", ("KIMI_API_KEY", "MOONSHOT_API_KEY")),
    "ollama": ("http://localhost:11434/v1", ()),
}
class OpenAIClient(BaseLLMClient):
    """Client for OpenAI-compatible providers.

    For native OpenAI models, uses the Responses API (/v1/responses) which
    supports reasoning_effort with function tools across all model families
    (GPT-4.1, GPT-5). Third-party compatible providers (xAI, OpenRouter,
    DeepSeek, Kimi, Ollama) use standard Chat Completions.
    """

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        provider: str = "openai",
        **kwargs,
    ):
        """Initialize the client.

        Args:
            model: Model identifier for the target provider.
            base_url: Optional override for the provider's API base URL.
            provider: Provider name (case-insensitive); "openai" or a key
                of _PROVIDER_CONFIG ("xai", "openrouter", "deepseek",
                "kimi", "ollama").
            **kwargs: Extra options; keys listed in _PASSTHROUGH_KWARGS
                are forwarded to ChatOpenAI.
        """
        super().__init__(model, base_url, **kwargs)
        self.provider = provider.lower()

    def _resolve_api_key(self, api_key_envs) -> str:
        """Return the API key from kwargs or the first set env var.

        An explicitly passed ``api_key`` kwarg takes precedence over the
        environment, preserving the original resolution order.

        Raises:
            ValueError: If no key is found in kwargs or the environment.
        """
        api_key = self.kwargs.get("api_key")
        if not api_key:
            for env_name in api_key_envs:
                api_key = os.environ.get(env_name)
                if api_key:
                    break
        if not api_key:
            api_key_env_list = ", ".join(api_key_envs)
            raise ValueError(
                f"Missing API key for provider '{self.provider}'. "
                f"Set one of: {api_key_env_list}, or pass api_key explicitly."
            )
        return api_key

    def get_llm(self) -> Any:
        """Return configured ChatOpenAI instance.

        Raises:
            ValueError: If the provider requires an API key and none is
                available (see _resolve_api_key).
        """
        self.warn_if_unknown_model()
        llm_kwargs = {"model": self.model}

        # Provider-specific base URL and auth.
        if self.provider in _PROVIDER_CONFIG:
            default_base_url, api_key_envs = _PROVIDER_CONFIG[self.provider]
            llm_kwargs["base_url"] = self.base_url or default_base_url
            if api_key_envs:
                llm_kwargs["api_key"] = self._resolve_api_key(api_key_envs)
            else:
                # Keyless providers (local Ollama): the OpenAI SDK still
                # requires a non-empty api_key, so use a placeholder.
                llm_kwargs["api_key"] = "ollama"
        elif self.base_url:
            # Native OpenAI (or unknown provider) with a custom endpoint.
            llm_kwargs["base_url"] = self.base_url

        # Forward user-provided kwargs. Note a user-supplied api_key here
        # intentionally overrides the placeholder/env-resolved value.
        for key in _PASSTHROUGH_KWARGS:
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        # Native OpenAI: use Responses API for consistent behavior across
        # all model families. Third-party providers use Chat Completions.
        if self.provider == "openai":
            llm_kwargs["use_responses_api"] = True

        return NormalizedChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Validate the configured model name for the provider."""
        return validate_model(self.provider, self.model)