fix: address Gemini review feedback for PR #397

- Fix Ollama URL normalization bug (trailing slash edge case)
- Use rstrip('/') before checking /v1 suffix
- Add DRY pattern for Azure env var fallbacks
- Add final newline to azure_client.py
- Use shared OLLAMA_MODELS/AZURE_MODELS constants to avoid duplicates
阳虎 2026-03-19 12:14:38 +08:00
parent 06859104e8
commit 6bdea25402
4 changed files with 97 additions and 13 deletions

View File

@@ -14,6 +14,22 @@ ANALYST_ORDER = [
     ("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
 ]
 
+# Shared model lists for providers with identical options in shallow/deep selections
+# Azure: deployment names are user-specific, so we provide common defaults
+AZURE_MODELS = [
+    ("GPT-4o (your deployment)", "gpt-4o"),
+    ("GPT-4o-mini (your deployment)", "gpt-4o-mini"),
+    ("GPT-4 Turbo (your deployment)", "gpt-4-turbo"),
+    ("Custom deployment name", "__custom__"),
+]
+
+# Ollama: same models for both shallow and deep (local inference)
+OLLAMA_MODELS = [
+    ("Qwen3:latest (8B, local)", "qwen3:latest"),
+    ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
+    ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
+]
+
 def get_ticker() -> str:
     """Prompt the user to enter a ticker symbol."""
@@ -159,11 +175,8 @@ def select_shallow_thinking_agent(provider) -> str:
             ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
             ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
         ],
-        "ollama": [
-            ("Qwen3:latest (8B, local)", "qwen3:latest"),
-            ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
-            ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
-        ],
+        "ollama": OLLAMA_MODELS,
+        "azure": AZURE_MODELS,
     }
 
     choice = questionary.select(
@@ -226,11 +239,8 @@ def select_deep_thinking_agent(provider) -> str:
             ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
             ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
         ],
-        "ollama": [
-            ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
-            ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
-            ("Qwen3:latest (8B, local)", "qwen3:latest"),
-        ],
+        "ollama": OLLAMA_MODELS,
+        "azure": AZURE_MODELS,
     }
 
     choice = questionary.select(
@@ -265,6 +275,7 @@ def select_llm_provider() -> tuple[str, str]:
         ("xAI", "https://api.x.ai/v1"),
         ("Openrouter", "https://openrouter.ai/api/v1"),
         ("Ollama", "http://localhost:11434/v1"),
+        ("Azure OpenAI", "azure"),
     ]
 
     choice = questionary.select(

azure_client.py View File

@@ -0,0 +1,63 @@
+"""Azure OpenAI client implementation."""
+
+import os
+from typing import Any, Optional
+
+from langchain_openai import AzureChatOpenAI
+
+from .base_client import BaseLLMClient
+
+
+class AzureClient(BaseLLMClient):
+    """Client for the Azure OpenAI provider.
+
+    Required environment variables:
+        AZURE_OPENAI_API_KEY: Your Azure OpenAI API key
+        AZURE_OPENAI_ENDPOINT: Your Azure endpoint (e.g., https://your-resource.openai.azure.com/)
+        OPENAI_API_VERSION: API version (e.g., 2024-02-15-preview)
+
+    Optional:
+        AZURE_DEPLOYMENT_NAME: Your deployment name (can also be passed as the model parameter)
+    """
+
+    def __init__(
+        self,
+        model: str,
+        base_url: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(model, base_url, **kwargs)
+        # In Azure, the model is the deployment name
+        self.deployment_name = model
+
+    def get_llm(self) -> Any:
+        """Return a configured AzureChatOpenAI instance."""
+        llm_kwargs = {
+            "azure_deployment": self.deployment_name,
+        }
+        # Pass through Azure-specific kwargs
+        for key in ("timeout", "max_retries", "callbacks", "api_version",
+                    "azure_endpoint", "api_key", "http_client", "http_async_client"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+        # Environment variable fallbacks (DRY pattern)
+        for kwarg_name, env_var in {
+            "azure_endpoint": "AZURE_OPENAI_ENDPOINT",
+            "api_key": "AZURE_OPENAI_API_KEY",
+        }.items():
+            if kwarg_name not in llm_kwargs:
+                value = os.environ.get(env_var)
+                if value:
+                    llm_kwargs[kwarg_name] = value
+        if "api_version" not in llm_kwargs:
+            api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview")
+            llm_kwargs["api_version"] = api_version
+        return AzureChatOpenAI(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Check that a deployment name was provided; Azure validates it at request time."""
+        return bool(self.deployment_name)
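
For reviewers, a minimal usage sketch of the new client, assuming the three required environment variables are exported; the deployment name "my-gpt-4o" and the flat import path are placeholders, not part of this diff:

    # Sketch only: adjust the import to the package's actual layout
    from azure_client import AzureClient

    client = AzureClient(model="my-gpt-4o")  # model doubles as the Azure deployment name
    llm = client.get_llm()                   # AzureChatOpenAI built from kwargs + env fallbacks
    print(llm.invoke("Say hello").content)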

View File

@@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
 from .openai_client import OpenAIClient
 from .anthropic_client import AnthropicClient
 from .google_client import GoogleClient
+from .azure_client import AzureClient
 
 
 def create_llm_client(
@@ -15,8 +16,8 @@ def create_llm_client(
     """Create an LLM client for the specified provider.
 
     Args:
-        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
-        model: Model name/identifier
+        provider: LLM provider (openai, anthropic, google, azure, xai, ollama, openrouter)
+        model: Model name/identifier (for Azure, use deployment name)
         base_url: Optional base URL for API endpoint
         **kwargs: Additional provider-specific arguments
@@ -40,4 +41,7 @@ def create_llm_client(
     if provider_lower == "google":
         return GoogleClient(model, base_url, **kwargs)
 
+    if provider_lower == "azure":
+        return AzureClient(model, base_url, **kwargs)
+
     raise ValueError(f"Unsupported LLM provider: {provider}")
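
A factory-level sketch of the new branch; the deployment name is again a placeholder. Matching is case-insensitive via provider_lower, so "Azure" from the CLI picker resolves correctly:

    # Sketch: "azure" now routes to AzureClient
    client = create_llm_client(provider="Azure", model="my-gpt-4o")
    assert type(client).__name__ == "AzureClient"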

openai_client.py View File

@@ -54,7 +54,13 @@ class OpenAIClient(BaseLLMClient):
             if api_key:
                 llm_kwargs["api_key"] = api_key
         elif self.provider == "ollama":
-            llm_kwargs["base_url"] = "http://localhost:11434/v1"
+            # Support custom Ollama URL via base_url or OLLAMA_HOST env var
+            ollama_url = self.base_url or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
+            # Ensure /v1 suffix for OpenAI-compatible API (handle trailing slash edge case)
+            ollama_url = ollama_url.rstrip("/")
+            if not ollama_url.endswith("/v1"):
+                ollama_url += "/v1"
+            llm_kwargs["base_url"] = ollama_url
             llm_kwargs["api_key"] = "ollama"  # Ollama doesn't require auth
         elif self.base_url:
             llm_kwargs["base_url"] = self.base_url
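
To illustrate the trailing-slash edge case this hunk fixes, here is a standalone copy of the inline normalization; the helper name is illustrative only and does not exist in the codebase:

    import os

    def normalize_ollama_url(base_url=None):
        # Standalone copy of OpenAIClient's inline Ollama URL logic (illustration only)
        url = base_url or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
        url = url.rstrip("/")  # the fix: strip trailing slashes before the suffix check
        if not url.endswith("/v1"):
            url += "/v1"
        return url

    assert normalize_ollama_url("http://localhost:11434")     == "http://localhost:11434/v1"
    assert normalize_ollama_url("http://localhost:11434/")    == "http://localhost:11434/v1"
    assert normalize_ollama_url("http://localhost:11434/v1/") == "http://localhost:11434/v1"
    # Without the rstrip, the last two inputs produced ".../11434//v1" and ".../v1//v1"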