fix: address Gemini review feedback for PR #397
- Fix Ollama URL normalization bug (trailing slash edge case)
- Use rstrip('/') before checking /v1 suffix
- Add DRY pattern for Azure env var fallbacks
- Add final newline to azure_client.py
- Use shared OLLAMA_MODELS/AZURE_MODELS constants to avoid duplicates
This commit is contained in:
parent
06859104e8
commit
6bdea25402
31
cli/utils.py
31
cli/utils.py
|
|
@ -14,6 +14,22 @@ ANALYST_ORDER = [
|
||||||
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
|
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Shared model lists for providers with identical options in shallow/deep selections
|
||||||
|
# Azure: deployment names are user-specific, so we provide common defaults
|
||||||
|
AZURE_MODELS = [
|
||||||
|
("GPT-4o (your deployment)", "gpt-4o"),
|
||||||
|
("GPT-4o-mini (your deployment)", "gpt-4o-mini"),
|
||||||
|
("GPT-4 Turbo (your deployment)", "gpt-4-turbo"),
|
||||||
|
("Custom deployment name", "__custom__"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Ollama: same models for both shallow and deep (local inference)
|
||||||
|
OLLAMA_MODELS = [
|
||||||
|
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||||
|
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||||
|
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_ticker() -> str:
|
def get_ticker() -> str:
|
||||||
"""Prompt the user to enter a ticker symbol."""
|
"""Prompt the user to enter a ticker symbol."""
|
||||||
|
|
@ -159,11 +175,8 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||||
],
|
],
|
||||||
"ollama": [
|
"ollama": OLLAMA_MODELS,
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
"azure": AZURE_MODELS,
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|
@ -226,11 +239,8 @@ def select_deep_thinking_agent(provider) -> str:
|
||||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||||
],
|
],
|
||||||
"ollama": [
|
"ollama": OLLAMA_MODELS,
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
"azure": AZURE_MODELS,
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|
@ -265,6 +275,7 @@ def select_llm_provider() -> tuple[str, str]:
|
||||||
("xAI", "https://api.x.ai/v1"),
|
("xAI", "https://api.x.ai/v1"),
|
||||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||||
("Ollama", "http://localhost:11434/v1"),
|
("Ollama", "http://localhost:11434/v1"),
|
||||||
|
("Azure OpenAI", "azure"),
|
||||||
]
|
]
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Azure OpenAI client implementation."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class AzureClient(BaseLLMClient):
|
||||||
|
"""Client for Azure OpenAI provider.
|
||||||
|
|
||||||
|
Required environment variables:
|
||||||
|
AZURE_OPENAI_API_KEY: Your Azure OpenAI API key
|
||||||
|
AZURE_OPENAI_ENDPOINT: Your Azure endpoint (e.g., https://your-resource.openai.azure.com/)
|
||||||
|
OPENAI_API_VERSION: API version (e.g., 2024-02-15-preview)
|
||||||
|
|
||||||
|
Optional:
|
||||||
|
AZURE_DEPLOYMENT_NAME: Your deployment name (can also be passed as model parameter)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(model, base_url, **kwargs)
|
||||||
|
# In Azure, model is the deployment name
|
||||||
|
self.deployment_name = model
|
||||||
|
|
||||||
|
def get_llm(self) -> Any:
|
||||||
|
"""Return configured AzureChatOpenAI instance."""
|
||||||
|
llm_kwargs = {
|
||||||
|
"azure_deployment": self.deployment_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Pass through Azure-specific kwargs
|
||||||
|
for key in ("timeout", "max_retries", "callbacks", "api_version",
|
||||||
|
"azure_endpoint", "api_key", "http_client", "http_async_client"):
|
||||||
|
if key in self.kwargs:
|
||||||
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
# Environment variable fallbacks (DRY pattern)
|
||||||
|
for kwarg_name, env_var in {
|
||||||
|
"azure_endpoint": "AZURE_OPENAI_ENDPOINT",
|
||||||
|
"api_key": "AZURE_OPENAI_API_KEY",
|
||||||
|
}.items():
|
||||||
|
if kwarg_name not in llm_kwargs:
|
||||||
|
value = os.environ.get(env_var)
|
||||||
|
if value:
|
||||||
|
llm_kwargs[kwarg_name] = value
|
||||||
|
|
||||||
|
if "api_version" not in llm_kwargs:
|
||||||
|
api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview")
|
||||||
|
llm_kwargs["api_version"] = api_version
|
||||||
|
|
||||||
|
return AzureChatOpenAI(**llm_kwargs)
|
||||||
|
|
||||||
|
def validate_model(self) -> bool:
|
||||||
|
"""Validate Azure deployment name (always valid, Azure validates at runtime)."""
|
||||||
|
return bool(self.deployment_name)
|
||||||
|
|
@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
from .openai_client import OpenAIClient
|
||||||
from .anthropic_client import AnthropicClient
|
from .anthropic_client import AnthropicClient
|
||||||
from .google_client import GoogleClient
|
from .google_client import GoogleClient
|
||||||
|
from .azure_client import AzureClient
|
||||||
|
|
||||||
|
|
||||||
def create_llm_client(
|
def create_llm_client(
|
||||||
|
|
@ -15,8 +16,8 @@ def create_llm_client(
|
||||||
"""Create an LLM client for the specified provider.
|
"""Create an LLM client for the specified provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
provider: LLM provider (openai, anthropic, google, azure, xai, ollama, openrouter)
|
||||||
model: Model name/identifier
|
model: Model name/identifier (for Azure, use deployment name)
|
||||||
base_url: Optional base URL for API endpoint
|
base_url: Optional base URL for API endpoint
|
||||||
**kwargs: Additional provider-specific arguments
|
**kwargs: Additional provider-specific arguments
|
||||||
|
|
||||||
|
|
@ -40,4 +41,7 @@ def create_llm_client(
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
return GoogleClient(model, base_url, **kwargs)
|
return GoogleClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "azure":
|
||||||
|
return AzureClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,13 @@ class OpenAIClient(BaseLLMClient):
|
||||||
if api_key:
|
if api_key:
|
||||||
llm_kwargs["api_key"] = api_key
|
llm_kwargs["api_key"] = api_key
|
||||||
elif self.provider == "ollama":
|
elif self.provider == "ollama":
|
||||||
llm_kwargs["base_url"] = "http://localhost:11434/v1"
|
# Support custom Ollama URL via base_url or OLLAMA_HOST env var
|
||||||
|
ollama_url = self.base_url or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
|
||||||
|
# Ensure /v1 suffix for OpenAI-compatible API (handle trailing slash edge case)
|
||||||
|
ollama_url = ollama_url.rstrip("/")
|
||||||
|
if not ollama_url.endswith("/v1"):
|
||||||
|
ollama_url += "/v1"
|
||||||
|
llm_kwargs["base_url"] = ollama_url
|
||||||
llm_kwargs["api_key"] = "ollama" # Ollama doesn't require auth
|
llm_kwargs["api_key"] = "ollama" # Ollama doesn't require auth
|
||||||
elif self.base_url:
|
elif self.base_url:
|
||||||
llm_kwargs["base_url"] = self.base_url
|
llm_kwargs["base_url"] = self.base_url
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue