refactor(config): improve config loading with validation and error handling

- Replace inline config loading with dedicated functions `_load_config` and `_get_section` (or `_get_base_urls`).
- Add explicit error handling for missing files, invalid JSON, and I/O errors.
- Validate config structure and section types to provide clear runtime errors.
- Centralize logic to avoid code duplication across modules.
This commit is contained in:
Maytekin 2026-03-24 16:29:19 +00:00
parent 065d033faf
commit 4261c62e7c
3 changed files with 109 additions and 18 deletions

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Dict from typing import List
import questionary import questionary
from rich.console import Console from rich.console import Console
@ -10,17 +10,50 @@ from cli.models import AnalystType
console = Console() console = Console()
CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json" CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json"
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
CONFIG = json.load(config_file)
BASE_URLS = [(display, url) for display, url in CONFIG["BASE_URLS"]] def _exit_with_config_error(message: str) -> None:
console.print(f"\n[red]{message}[/red]")
exit(1)
def _load_config() -> dict:
try:
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
return json.load(config_file)
except FileNotFoundError:
_exit_with_config_error(f"Config file not found: {CONFIG_PATH}")
except json.JSONDecodeError as exc:
_exit_with_config_error(f"Invalid JSON in config file: {exc}")
except OSError as exc:
_exit_with_config_error(f"Unable to read config file: {exc}")
return {}
def _get_config_section(config: dict, key: str, expected_type: type):
value = config.get(key)
if not isinstance(value, expected_type):
_exit_with_config_error(
f"Invalid or missing '{key}' in config file: {CONFIG_PATH}"
)
return value
CONFIG = _load_config()
BASE_URLS = [
(display, url)
for display, url in _get_config_section(CONFIG, "BASE_URLS", list)
]
DEEP_AGENT_OPTIONS = { DEEP_AGENT_OPTIONS = {
provider: [(display, value) for display, value in options] provider: [(display, model) for display, model in options]
for provider, options in CONFIG["DEEP_AGENT_OPTIONS"].items() for provider, options in _get_config_section(
CONFIG, "DEEP_AGENT_OPTIONS", dict
).items()
} }
SHALLOW_AGENT_OPTIONS = { SHALLOW_AGENT_OPTIONS = {
provider: [(display, value) for display, value in options] provider: [(display, model) for display, model in options]
for provider, options in CONFIG["SHALLOW_AGENT_OPTIONS"].items() for provider, options in _get_config_section(
CONFIG, "SHALLOW_AGENT_OPTIONS", dict
).items()
} }
TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK" TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"

View File

@ -3,11 +3,37 @@ import os
from pathlib import Path from pathlib import Path
CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json" CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json"
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
CONFIG = json.load(config_file)
DEFAULT_LLM_SETTINGS = CONFIG.get("DEFAULT_LLM_SETTINGS", {})
BASE_URLS = {display.lower(): url for display, url in CONFIG.get("BASE_URLS", [])} def _load_config() -> dict:
try:
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
config = json.load(config_file)
except FileNotFoundError as exc:
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
except json.JSONDecodeError as exc:
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
except OSError as exc:
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
if not isinstance(config, dict):
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
return config
def _get_section(config: dict, key: str, expected_type: type):
section = config.get(key)
if not isinstance(section, expected_type):
raise RuntimeError(f"Invalid or missing '{key}' in config file: {CONFIG_PATH}")
return section
CONFIG = _load_config()
DEFAULT_LLM_SETTINGS = _get_section(CONFIG, "DEFAULT_LLM_SETTINGS", dict)
BASE_URLS = {
display.lower(): url
for display, url in _get_section(CONFIG, "BASE_URLS", list)
}
DEFAULT_PROVIDER = DEFAULT_LLM_SETTINGS.get("llm_provider", "openai").lower() DEFAULT_PROVIDER = DEFAULT_LLM_SETTINGS.get("llm_provider", "openai").lower()
DEFAULT_BACKEND_URL = BASE_URLS.get( DEFAULT_BACKEND_URL = BASE_URLS.get(
DEFAULT_PROVIDER, BASE_URLS.get("openai", "https://api.openai.com/v1") DEFAULT_PROVIDER, BASE_URLS.get("openai", "https://api.openai.com/v1")

View File

@ -28,14 +28,46 @@ _PASSTHROUGH_KWARGS = (
) )
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json" CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
CONFIG = json.load(config_file)
def _load_config() -> dict:
try:
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
config = json.load(config_file)
except FileNotFoundError as exc:
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
except json.JSONDecodeError as exc:
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
except OSError as exc:
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
if not isinstance(config, dict):
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
return config
def _get_base_urls(config: dict) -> dict[str, str]:
base_urls = config.get("BASE_URLS")
if not isinstance(base_urls, list):
raise RuntimeError(f"Invalid or missing 'BASE_URLS' in config file: {CONFIG_PATH}")
mapped_urls: dict[str, str] = {}
for item in base_urls:
if (
isinstance(item, list)
and len(item) == 2
and isinstance(item[0], str)
and isinstance(item[1], str)
):
mapped_urls[item[0].lower()] = item[1]
return mapped_urls
CONFIG = _load_config()
load_dotenv() load_dotenv()
_BASE_URLS = { _BASE_URLS = _get_base_urls(CONFIG)
display.lower(): url for display, url in CONFIG.get("BASE_URLS", [])
}
_PROVIDER_BASE_URLS = { _PROVIDER_BASE_URLS = {
"xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"), "xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"),
"openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"), "openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"),
@ -68,7 +100,7 @@ class OpenAIClient(BaseLLMClient):
def get_llm(self) -> Any: def get_llm(self) -> Any:
"""Return configured ChatOpenAI instance.""" """Return configured ChatOpenAI instance."""
llm_kwargs = {"model": self.model} llm_kwargs: dict[str, Any] = {"model": self.model}
# Provider-specific base URL and auth # Provider-specific base URL and auth
if self.provider in _PROVIDER_BASE_URLS: if self.provider in _PROVIDER_BASE_URLS: