refactor(llm_clients): extract config loading logic into separate module
Move config loading and validation functions from openai_client.py and factory.py into a new shared config_loader.py module. This centralizes configuration handling, reduces code duplication, and improves maintainability. The factory now gracefully falls back to default provider types if config loading fails.
This commit is contained in:
parent
244a986e83
commit
d1818de073
|
|
@ -0,0 +1,41 @@
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
||||||
|
|
||||||
|
|
||||||
|
def load_config() -> dict:
|
||||||
|
try:
|
||||||
|
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||||
|
config = json.load(config_file)
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
||||||
|
except OSError as exc:
|
||||||
|
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_section(config: dict, key: str, expected_type: type) -> Any:
|
||||||
|
value = config.get(key)
|
||||||
|
if not isinstance(value, expected_type):
|
||||||
|
raise RuntimeError(f"Invalid or missing '{key}' in config file: {CONFIG_PATH}")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_urls_map(config: dict) -> dict[str, str]:
|
||||||
|
base_urls = get_config_section(config, "BASE_URLS", list)
|
||||||
|
mapped_urls: dict[str, str] = {}
|
||||||
|
for item in base_urls:
|
||||||
|
if (
|
||||||
|
isinstance(item, list)
|
||||||
|
and len(item) == 2
|
||||||
|
and isinstance(item[0], str)
|
||||||
|
and isinstance(item[1], str)
|
||||||
|
):
|
||||||
|
mapped_urls[item[0].lower()] = item[1]
|
||||||
|
return mapped_urls
|
||||||
|
|
@ -1,40 +1,39 @@
|
||||||
import json
|
import warnings
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base_client import BaseLLMClient
|
from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
from .openai_client import OpenAIClient
|
||||||
from .anthropic_client import AnthropicClient
|
from .anthropic_client import AnthropicClient
|
||||||
from .google_client import GoogleClient
|
from .google_client import GoogleClient
|
||||||
|
from .config_loader import load_config, get_config_section
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
DEFAULT_PROVIDER_TYPES: dict[str, str] = {
|
||||||
|
"openai": "openai",
|
||||||
|
"anthropic": "anthropic",
|
||||||
def _load_config() -> dict:
|
"google": "google",
|
||||||
try:
|
"xai": "openai",
|
||||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
"openrouter": "openai",
|
||||||
config = json.load(config_file)
|
"ollama": "openai",
|
||||||
except FileNotFoundError as exc:
|
"lmstudio": "openai",
|
||||||
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
}
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
|
||||||
except OSError as exc:
|
|
||||||
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
|
||||||
if not isinstance(config, dict):
|
|
||||||
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def _load_provider_types() -> dict[str, str]:
|
def _load_provider_types() -> dict[str, str]:
|
||||||
provider_types = _load_config().get("LLM_PROVIDER_TYPES")
|
try:
|
||||||
if not isinstance(provider_types, dict):
|
config = load_config()
|
||||||
raise RuntimeError(
|
provider_types = get_config_section(config, "LLM_PROVIDER_TYPES", dict)
|
||||||
f"Invalid or missing 'LLM_PROVIDER_TYPES' in config file: {CONFIG_PATH}"
|
return {
|
||||||
|
str(name).lower(): str(client_type).lower()
|
||||||
|
for name, client_type in provider_types.items()
|
||||||
|
}
|
||||||
|
except RuntimeError as exc:
|
||||||
|
warnings.warn(
|
||||||
|
f"Failed to load LLM_PROVIDER_TYPES from config.json: {exc}. "
|
||||||
|
"Using built-in provider mapping.",
|
||||||
|
RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
return {
|
return DEFAULT_PROVIDER_TYPES.copy()
|
||||||
str(name).lower(): str(client_type).lower()
|
|
||||||
for name, client_type in provider_types.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
_PROVIDER_TYPES: dict[str, str] | None = None
|
_PROVIDER_TYPES: dict[str, str] | None = None
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
from .base_client import BaseLLMClient, normalize_content
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
|
from .config_loader import load_config, get_base_urls_map
|
||||||
from .validators import validate_model
|
from .validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -27,47 +26,11 @@ _PASSTHROUGH_KWARGS = (
|
||||||
"api_key", "callbacks", "http_client", "http_async_client",
|
"api_key", "callbacks", "http_client", "http_async_client",
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
CONFIG = load_config()
|
||||||
|
|
||||||
|
|
||||||
def _load_config() -> dict:
|
|
||||||
try:
|
|
||||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
|
||||||
config = json.load(config_file)
|
|
||||||
except FileNotFoundError as exc:
|
|
||||||
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
|
||||||
except OSError as exc:
|
|
||||||
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
|
||||||
|
|
||||||
if not isinstance(config, dict):
|
|
||||||
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def _get_base_urls(config: dict) -> dict[str, str]:
|
|
||||||
base_urls = config.get("BASE_URLS")
|
|
||||||
if not isinstance(base_urls, list):
|
|
||||||
raise RuntimeError(f"Invalid or missing 'BASE_URLS' in config file: {CONFIG_PATH}")
|
|
||||||
|
|
||||||
mapped_urls: dict[str, str] = {}
|
|
||||||
for item in base_urls:
|
|
||||||
if (
|
|
||||||
isinstance(item, list)
|
|
||||||
and len(item) == 2
|
|
||||||
and isinstance(item[0], str)
|
|
||||||
and isinstance(item[1], str)
|
|
||||||
):
|
|
||||||
mapped_urls[item[0].lower()] = item[1]
|
|
||||||
return mapped_urls
|
|
||||||
|
|
||||||
|
|
||||||
CONFIG = _load_config()
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
_BASE_URLS = _get_base_urls(CONFIG)
|
_BASE_URLS = get_base_urls_map(CONFIG)
|
||||||
_PROVIDER_BASE_URLS = {
|
_PROVIDER_BASE_URLS = {
|
||||||
"xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"),
|
"xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"),
|
||||||
"openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"),
|
"openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue