refactor(config): improve config loading with validation and error handling
- Replace inline config loading with dedicated functions `_load_config` and `_get_section` (or `_get_base_urls`).
- Add explicit error handling for missing files, invalid JSON, and I/O errors.
- Validate config structure and section types to provide clear runtime errors.
- Centralize logic to avoid code duplication across modules.
This commit is contained in:
parent
065d033faf
commit
4261c62e7c
49
cli/utils.py
49
cli/utils.py
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
from typing import List
|
||||
|
||||
import questionary
|
||||
from rich.console import Console
|
||||
|
|
@ -10,17 +10,50 @@ from cli.models import AnalystType
|
|||
console = Console()
|
||||
|
||||
CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json"
|
||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||
CONFIG = json.load(config_file)
|
||||
|
||||
BASE_URLS = [(display, url) for display, url in CONFIG["BASE_URLS"]]
|
||||
def _exit_with_config_error(message: str) -> None:
    """Print *message* in red to the console and terminate the process.

    Raises:
        SystemExit: always, with exit status 1.
    """
    console.print(f"\n[red]{message}[/red]")
    # The `exit()` builtin is a `site`-module convenience intended for the
    # interactive interpreter and may be missing when site is disabled;
    # `raise SystemExit(1)` is the reliable programmatic equivalent.
    raise SystemExit(1)
|
||||
|
||||
|
||||
def _load_config() -> dict:
    """Load and parse the JSON config at CONFIG_PATH.

    Prints a readable error and exits the process (via
    `_exit_with_config_error`) when the file is missing, unreadable,
    not valid JSON, or not a JSON object at the top level.

    Returns:
        The parsed configuration mapping. The trailing fall-through return
        exists only to satisfy type checkers — every error path exits.
    """
    config: dict = {}
    try:
        with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
            config = json.load(config_file)
    except FileNotFoundError:
        _exit_with_config_error(f"Config file not found: {CONFIG_PATH}")
    except json.JSONDecodeError as exc:
        _exit_with_config_error(f"Invalid JSON in config file: {exc}")
    except OSError as exc:
        _exit_with_config_error(f"Unable to read config file: {exc}")
    # Mirror the sibling loader: the top-level JSON value must be an object,
    # otherwise downstream `config.get(...)` calls would fail confusingly.
    if not isinstance(config, dict):
        _exit_with_config_error(f"Invalid config format in file: {CONFIG_PATH}")
    return config
|
||||
|
||||
|
||||
def _get_config_section(config: dict, key: str, expected_type: type):
|
||||
value = config.get(key)
|
||||
if not isinstance(value, expected_type):
|
||||
_exit_with_config_error(
|
||||
f"Invalid or missing '{key}' in config file: {CONFIG_PATH}"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
CONFIG = _load_config()
|
||||
BASE_URLS = [
|
||||
(display, url)
|
||||
for display, url in _get_config_section(CONFIG, "BASE_URLS", list)
|
||||
]
|
||||
DEEP_AGENT_OPTIONS = {
|
||||
provider: [(display, value) for display, value in options]
|
||||
for provider, options in CONFIG["DEEP_AGENT_OPTIONS"].items()
|
||||
provider: [(display, model) for display, model in options]
|
||||
for provider, options in _get_config_section(
|
||||
CONFIG, "DEEP_AGENT_OPTIONS", dict
|
||||
).items()
|
||||
}
|
||||
SHALLOW_AGENT_OPTIONS = {
|
||||
provider: [(display, value) for display, value in options]
|
||||
for provider, options in CONFIG["SHALLOW_AGENT_OPTIONS"].items()
|
||||
provider: [(display, model) for display, model in options]
|
||||
for provider, options in _get_config_section(
|
||||
CONFIG, "SHALLOW_AGENT_OPTIONS", dict
|
||||
).items()
|
||||
}
|
||||
|
||||
TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
|
||||
|
|
|
|||
|
|
@ -3,11 +3,37 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json"
|
||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||
CONFIG = json.load(config_file)
|
||||
|
||||
DEFAULT_LLM_SETTINGS = CONFIG.get("DEFAULT_LLM_SETTINGS", {})
|
||||
BASE_URLS = {display.lower(): url for display, url in CONFIG.get("BASE_URLS", [])}
|
||||
|
||||
def _load_config() -> dict:
    """Read and parse the JSON config at CONFIG_PATH.

    Returns:
        The parsed configuration as a dict.

    Raises:
        RuntimeError: when the file is missing, unreadable, contains
            invalid JSON, or its top-level value is not a JSON object.
    """
    # Read first, parse second — keeps each try block to the single
    # operation that can raise.
    try:
        raw_text = CONFIG_PATH.read_text(encoding="utf-8")
    except FileNotFoundError as exc:
        raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
    except OSError as exc:
        raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc

    try:
        parsed = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc

    if isinstance(parsed, dict):
        return parsed
    raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
||||
|
||||
|
||||
def _get_section(config: dict, key: str, expected_type: type):
|
||||
section = config.get(key)
|
||||
if not isinstance(section, expected_type):
|
||||
raise RuntimeError(f"Invalid or missing '{key}' in config file: {CONFIG_PATH}")
|
||||
return section
|
||||
|
||||
|
||||
CONFIG = _load_config()
|
||||
DEFAULT_LLM_SETTINGS = _get_section(CONFIG, "DEFAULT_LLM_SETTINGS", dict)
|
||||
BASE_URLS = {
|
||||
display.lower(): url
|
||||
for display, url in _get_section(CONFIG, "BASE_URLS", list)
|
||||
}
|
||||
DEFAULT_PROVIDER = DEFAULT_LLM_SETTINGS.get("llm_provider", "openai").lower()
|
||||
DEFAULT_BACKEND_URL = BASE_URLS.get(
|
||||
DEFAULT_PROVIDER, BASE_URLS.get("openai", "https://api.openai.com/v1")
|
||||
|
|
|
|||
|
|
@ -28,14 +28,46 @@ _PASSTHROUGH_KWARGS = (
|
|||
)
|
||||
|
||||
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||
CONFIG = json.load(config_file)
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
try:
|
||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||
config = json.load(config_file)
|
||||
except FileNotFoundError as exc:
|
||||
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
||||
|
||||
if not isinstance(config, dict):
|
||||
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
||||
return config
|
||||
|
||||
|
||||
def _get_base_urls(config: dict) -> dict[str, str]:
|
||||
base_urls = config.get("BASE_URLS")
|
||||
if not isinstance(base_urls, list):
|
||||
raise RuntimeError(f"Invalid or missing 'BASE_URLS' in config file: {CONFIG_PATH}")
|
||||
|
||||
mapped_urls: dict[str, str] = {}
|
||||
for item in base_urls:
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) == 2
|
||||
and isinstance(item[0], str)
|
||||
and isinstance(item[1], str)
|
||||
):
|
||||
mapped_urls[item[0].lower()] = item[1]
|
||||
return mapped_urls
|
||||
|
||||
|
||||
CONFIG = _load_config()
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_BASE_URLS = {
|
||||
display.lower(): url for display, url in CONFIG.get("BASE_URLS", [])
|
||||
}
|
||||
_BASE_URLS = _get_base_urls(CONFIG)
|
||||
_PROVIDER_BASE_URLS = {
|
||||
"xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"),
|
||||
"openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"),
|
||||
|
|
@ -68,7 +100,7 @@ class OpenAIClient(BaseLLMClient):
|
|||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatOpenAI instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
llm_kwargs: dict[str, Any] = {"model": self.model}
|
||||
|
||||
# Provider-specific base URL and auth
|
||||
if self.provider in _PROVIDER_BASE_URLS:
|
||||
|
|
|
|||
Loading…
Reference in New Issue