feat: add Azure OpenAI support
- Add AzureClient class using langchain_openai.AzureChatOpenAI - Add Azure to LLM provider factory - Add Azure deployment options to CLI selectors - Support environment variables: AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, OPENAI_API_VERSION Fixes #334
This commit is contained in:
parent
f362a160c3
commit
5bc20dfe90
13
cli/utils.py
13
cli/utils.py
|
|
@ -164,6 +164,12 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
],
|
||||
"azure": [
|
||||
("GPT-4o (your deployment)", "gpt-4o"),
|
||||
("GPT-4o-mini (your deployment)", "gpt-4o-mini"),
|
||||
("GPT-4 Turbo (your deployment)", "gpt-4-turbo"),
|
||||
("Custom deployment name", "__custom__"),
|
||||
],
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
@ -231,6 +237,12 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
],
|
||||
"azure": [
|
||||
("GPT-4o (your deployment)", "gpt-4o"),
|
||||
("GPT-4o-mini (your deployment)", "gpt-4o-mini"),
|
||||
("GPT-4 Turbo (your deployment)", "gpt-4-turbo"),
|
||||
("Custom deployment name", "__custom__"),
|
||||
],
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
@ -265,6 +277,7 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
("xAI", "https://api.x.ai/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
("Azure OpenAI", "azure"),
|
||||
]
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
"""Azure OpenAI client implementation."""
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
||||
|
||||
class AzureClient(BaseLLMClient):
|
||||
"""Client for Azure OpenAI provider.
|
||||
|
||||
Required environment variables:
|
||||
AZURE_OPENAI_API_KEY: Your Azure OpenAI API key
|
||||
AZURE_OPENAI_ENDPOINT: Your Azure endpoint (e.g., https://your-resource.openai.azure.com/)
|
||||
OPENAI_API_VERSION: API version (e.g., 2024-02-15-preview)
|
||||
|
||||
Optional:
|
||||
AZURE_DEPLOYMENT_NAME: Your deployment name (can also be passed as model parameter)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
# In Azure, model is the deployment name
|
||||
self.deployment_name = model
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured AzureChatOpenAI instance."""
|
||||
llm_kwargs = {
|
||||
"azure_deployment": self.deployment_name,
|
||||
}
|
||||
|
||||
# Pass through Azure-specific kwargs
|
||||
for key in ("timeout", "max_retries", "callbacks", "api_version",
|
||||
"azure_endpoint", "api_key", "http_client", "http_async_client"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
# Environment variable fallbacks
|
||||
if "azure_endpoint" not in llm_kwargs:
|
||||
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
if endpoint:
|
||||
llm_kwargs["azure_endpoint"] = endpoint
|
||||
|
||||
if "api_key" not in llm_kwargs:
|
||||
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
if api_key:
|
||||
llm_kwargs["api_key"] = api_key
|
||||
|
||||
if "api_version" not in llm_kwargs:
|
||||
api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview")
|
||||
llm_kwargs["api_version"] = api_version
|
||||
|
||||
return AzureChatOpenAI(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate Azure deployment name (always valid, Azure validates at runtime)."""
|
||||
return bool(self.deployment_name)
|
||||
|
|
@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
|
|||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .azure_client import AzureClient
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
|
|
@ -15,8 +16,8 @@ def create_llm_client(
|
|||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||
model: Model name/identifier
|
||||
provider: LLM provider (openai, anthropic, google, azure, xai, ollama, openrouter)
|
||||
model: Model name/identifier (for Azure, use deployment name)
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
- http_client: Custom httpx.Client for SSL proxy or certificate customization
|
||||
|
|
@ -46,4 +47,7 @@ def create_llm_client(
|
|||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "azure":
|
||||
return AzureClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue