feat: add Azure OpenAI support

- Add AzureClient class using langchain_openai.AzureChatOpenAI
- Add Azure to LLM provider factory
- Add Azure deployment options to CLI selectors
- Support environment variables: AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, OPENAI_API_VERSION

Fixes #334
This commit is contained in:
阳虎 2026-03-16 08:18:36 +08:00
parent f362a160c3
commit 5bc20dfe90
3 changed files with 83 additions and 2 deletions

View File

@ -164,6 +164,12 @@ def select_shallow_thinking_agent(provider) -> str:
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
],
"azure": [
("GPT-4o (your deployment)", "gpt-4o"),
("GPT-4o-mini (your deployment)", "gpt-4o-mini"),
("GPT-4 Turbo (your deployment)", "gpt-4-turbo"),
("Custom deployment name", "__custom__"),
],
}
choice = questionary.select(
@ -231,6 +237,12 @@ def select_deep_thinking_agent(provider) -> str:
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("Qwen3:latest (8B, local)", "qwen3:latest"),
],
"azure": [
("GPT-4o (your deployment)", "gpt-4o"),
("GPT-4o-mini (your deployment)", "gpt-4o-mini"),
("GPT-4 Turbo (your deployment)", "gpt-4-turbo"),
("Custom deployment name", "__custom__"),
],
}
choice = questionary.select(
@ -265,6 +277,7 @@ def select_llm_provider() -> tuple[str, str]:
("xAI", "https://api.x.ai/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
("Azure OpenAI", "azure"),
]
choice = questionary.select(

View File

@ -0,0 +1,64 @@
"""Azure OpenAI client implementation."""
import os
from typing import Any, Optional
from langchain_openai import AzureChatOpenAI
from .base_client import BaseLLMClient
class AzureClient(BaseLLMClient):
"""Client for Azure OpenAI provider.
Required environment variables:
AZURE_OPENAI_API_KEY: Your Azure OpenAI API key
AZURE_OPENAI_ENDPOINT: Your Azure endpoint (e.g., https://your-resource.openai.azure.com/)
OPENAI_API_VERSION: API version (e.g., 2024-02-15-preview)
Optional:
AZURE_DEPLOYMENT_NAME: Your deployment name (can also be passed as model parameter)
"""
def __init__(
self,
model: str,
base_url: Optional[str] = None,
**kwargs,
):
super().__init__(model, base_url, **kwargs)
# In Azure, model is the deployment name
self.deployment_name = model
def get_llm(self) -> Any:
"""Return configured AzureChatOpenAI instance."""
llm_kwargs = {
"azure_deployment": self.deployment_name,
}
# Pass through Azure-specific kwargs
for key in ("timeout", "max_retries", "callbacks", "api_version",
"azure_endpoint", "api_key", "http_client", "http_async_client"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
# Environment variable fallbacks
if "azure_endpoint" not in llm_kwargs:
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
if endpoint:
llm_kwargs["azure_endpoint"] = endpoint
if "api_key" not in llm_kwargs:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
if api_key:
llm_kwargs["api_key"] = api_key
if "api_version" not in llm_kwargs:
api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview")
llm_kwargs["api_version"] = api_version
return AzureChatOpenAI(**llm_kwargs)
def validate_model(self) -> bool:
"""Validate Azure deployment name (always valid, Azure validates at runtime)."""
return bool(self.deployment_name)

View File

@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .azure_client import AzureClient
def create_llm_client(
@ -15,8 +16,8 @@ def create_llm_client(
"""Create an LLM client for the specified provider.
Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
model: Model name/identifier
provider: LLM provider (openai, anthropic, google, azure, xai, ollama, openrouter)
model: Model name/identifier (for Azure, use deployment name)
base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments
- http_client: Custom httpx.Client for SSL proxy or certificate customization
@ -46,4 +47,7 @@ def create_llm_client(
if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs)
if provider_lower == "azure":
return AzureClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")