refactor(llm_clients): extract config loading logic into separate module
Move config loading and validation functions from openai_client.py and factory.py into a new shared config_loader.py module. This centralizes configuration handling, reduces code duplication, and improves maintainability. The factory now gracefully falls back to default provider types if config loading fails.
This commit is contained in:
parent
244a986e83
commit
d1818de073
|
|
@ -0,0 +1,41 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
||||
|
||||
|
||||
def load_config() -> dict:
|
||||
try:
|
||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||
config = json.load(config_file)
|
||||
except FileNotFoundError as exc:
|
||||
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
||||
if not isinstance(config, dict):
|
||||
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
||||
return config
|
||||
|
||||
|
||||
def get_config_section(config: dict, key: str, expected_type: type) -> Any:
|
||||
value = config.get(key)
|
||||
if not isinstance(value, expected_type):
|
||||
raise RuntimeError(f"Invalid or missing '{key}' in config file: {CONFIG_PATH}")
|
||||
return value
|
||||
|
||||
|
||||
def get_base_urls_map(config: dict) -> dict[str, str]:
|
||||
base_urls = get_config_section(config, "BASE_URLS", list)
|
||||
mapped_urls: dict[str, str] = {}
|
||||
for item in base_urls:
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) == 2
|
||||
and isinstance(item[0], str)
|
||||
and isinstance(item[1], str)
|
||||
):
|
||||
mapped_urls[item[0].lower()] = item[1]
|
||||
return mapped_urls
|
||||
|
|
@ -1,40 +1,39 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .config_loader import load_config, get_config_section
|
||||
|
||||
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
try:
|
||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||
config = json.load(config_file)
|
||||
except FileNotFoundError as exc:
|
||||
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
||||
if not isinstance(config, dict):
|
||||
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
||||
return config
|
||||
DEFAULT_PROVIDER_TYPES: dict[str, str] = {
|
||||
"openai": "openai",
|
||||
"anthropic": "anthropic",
|
||||
"google": "google",
|
||||
"xai": "openai",
|
||||
"openrouter": "openai",
|
||||
"ollama": "openai",
|
||||
"lmstudio": "openai",
|
||||
}
|
||||
|
||||
|
||||
def _load_provider_types() -> dict[str, str]:
|
||||
provider_types = _load_config().get("LLM_PROVIDER_TYPES")
|
||||
if not isinstance(provider_types, dict):
|
||||
raise RuntimeError(
|
||||
f"Invalid or missing 'LLM_PROVIDER_TYPES' in config file: {CONFIG_PATH}"
|
||||
try:
|
||||
config = load_config()
|
||||
provider_types = get_config_section(config, "LLM_PROVIDER_TYPES", dict)
|
||||
return {
|
||||
str(name).lower(): str(client_type).lower()
|
||||
for name, client_type in provider_types.items()
|
||||
}
|
||||
except RuntimeError as exc:
|
||||
warnings.warn(
|
||||
f"Failed to load LLM_PROVIDER_TYPES from config.json: {exc}. "
|
||||
"Using built-in provider mapping.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return {
|
||||
str(name).lower(): str(client_type).lower()
|
||||
for name, client_type in provider_types.items()
|
||||
}
|
||||
return DEFAULT_PROVIDER_TYPES.copy()
|
||||
|
||||
|
||||
_PROVIDER_TYPES: dict[str, str] | None = None
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .config_loader import load_config, get_base_urls_map
|
||||
from .validators import validate_model
|
||||
|
||||
|
||||
|
|
@ -27,47 +26,11 @@ _PASSTHROUGH_KWARGS = (
|
|||
"api_key", "callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
try:
|
||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||
config = json.load(config_file)
|
||||
except FileNotFoundError as exc:
|
||||
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
||||
|
||||
if not isinstance(config, dict):
|
||||
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
||||
return config
|
||||
|
||||
|
||||
def _get_base_urls(config: dict) -> dict[str, str]:
|
||||
base_urls = config.get("BASE_URLS")
|
||||
if not isinstance(base_urls, list):
|
||||
raise RuntimeError(f"Invalid or missing 'BASE_URLS' in config file: {CONFIG_PATH}")
|
||||
|
||||
mapped_urls: dict[str, str] = {}
|
||||
for item in base_urls:
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) == 2
|
||||
and isinstance(item[0], str)
|
||||
and isinstance(item[1], str)
|
||||
):
|
||||
mapped_urls[item[0].lower()] = item[1]
|
||||
return mapped_urls
|
||||
|
||||
|
||||
CONFIG = _load_config()
|
||||
CONFIG = load_config()
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_BASE_URLS = _get_base_urls(CONFIG)
|
||||
_BASE_URLS = get_base_urls_map(CONFIG)
|
||||
_PROVIDER_BASE_URLS = {
|
||||
"xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"),
|
||||
"openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"),
|
||||
|
|
|
|||
Loading…
Reference in New Issue