refactor(config): improve config loading with validation and error handling
- Replace inline config loading with dedicated functions `_load_config` and `_get_section` (or `_get_base_urls`). - Add explicit error handling for missing files, invalid JSON, and I/O errors. - Validate config structure and section types to provide clear runtime errors. - Centralize logic to avoid code duplication across modules.
This commit is contained in:
parent
065d033faf
commit
4261c62e7c
49
cli/utils.py
49
cli/utils.py
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Dict
|
from typing import List
|
||||||
|
|
||||||
import questionary
|
import questionary
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
@ -10,17 +10,50 @@ from cli.models import AnalystType
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json"
|
CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json"
|
||||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
|
||||||
CONFIG = json.load(config_file)
|
|
||||||
|
|
||||||
BASE_URLS = [(display, url) for display, url in CONFIG["BASE_URLS"]]
|
def _exit_with_config_error(message: str) -> None:
|
||||||
|
console.print(f"\n[red]{message}[/red]")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_config() -> dict:
|
||||||
|
try:
|
||||||
|
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||||
|
return json.load(config_file)
|
||||||
|
except FileNotFoundError:
|
||||||
|
_exit_with_config_error(f"Config file not found: {CONFIG_PATH}")
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
_exit_with_config_error(f"Invalid JSON in config file: {exc}")
|
||||||
|
except OSError as exc:
|
||||||
|
_exit_with_config_error(f"Unable to read config file: {exc}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_config_section(config: dict, key: str, expected_type: type):
|
||||||
|
value = config.get(key)
|
||||||
|
if not isinstance(value, expected_type):
|
||||||
|
_exit_with_config_error(
|
||||||
|
f"Invalid or missing '{key}' in config file: {CONFIG_PATH}"
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG = _load_config()
|
||||||
|
BASE_URLS = [
|
||||||
|
(display, url)
|
||||||
|
for display, url in _get_config_section(CONFIG, "BASE_URLS", list)
|
||||||
|
]
|
||||||
DEEP_AGENT_OPTIONS = {
|
DEEP_AGENT_OPTIONS = {
|
||||||
provider: [(display, value) for display, value in options]
|
provider: [(display, model) for display, model in options]
|
||||||
for provider, options in CONFIG["DEEP_AGENT_OPTIONS"].items()
|
for provider, options in _get_config_section(
|
||||||
|
CONFIG, "DEEP_AGENT_OPTIONS", dict
|
||||||
|
).items()
|
||||||
}
|
}
|
||||||
SHALLOW_AGENT_OPTIONS = {
|
SHALLOW_AGENT_OPTIONS = {
|
||||||
provider: [(display, value) for display, value in options]
|
provider: [(display, model) for display, model in options]
|
||||||
for provider, options in CONFIG["SHALLOW_AGENT_OPTIONS"].items()
|
for provider, options in _get_config_section(
|
||||||
|
CONFIG, "SHALLOW_AGENT_OPTIONS", dict
|
||||||
|
).items()
|
||||||
}
|
}
|
||||||
|
|
||||||
TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
|
TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK"
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,37 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json"
|
CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json"
|
||||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
|
||||||
CONFIG = json.load(config_file)
|
|
||||||
|
|
||||||
DEFAULT_LLM_SETTINGS = CONFIG.get("DEFAULT_LLM_SETTINGS", {})
|
|
||||||
BASE_URLS = {display.lower(): url for display, url in CONFIG.get("BASE_URLS", [])}
|
def _load_config() -> dict:
|
||||||
|
try:
|
||||||
|
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||||
|
config = json.load(config_file)
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
||||||
|
except OSError as exc:
|
||||||
|
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
||||||
|
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _get_section(config: dict, key: str, expected_type: type):
|
||||||
|
section = config.get(key)
|
||||||
|
if not isinstance(section, expected_type):
|
||||||
|
raise RuntimeError(f"Invalid or missing '{key}' in config file: {CONFIG_PATH}")
|
||||||
|
return section
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG = _load_config()
|
||||||
|
DEFAULT_LLM_SETTINGS = _get_section(CONFIG, "DEFAULT_LLM_SETTINGS", dict)
|
||||||
|
BASE_URLS = {
|
||||||
|
display.lower(): url
|
||||||
|
for display, url in _get_section(CONFIG, "BASE_URLS", list)
|
||||||
|
}
|
||||||
DEFAULT_PROVIDER = DEFAULT_LLM_SETTINGS.get("llm_provider", "openai").lower()
|
DEFAULT_PROVIDER = DEFAULT_LLM_SETTINGS.get("llm_provider", "openai").lower()
|
||||||
DEFAULT_BACKEND_URL = BASE_URLS.get(
|
DEFAULT_BACKEND_URL = BASE_URLS.get(
|
||||||
DEFAULT_PROVIDER, BASE_URLS.get("openai", "https://api.openai.com/v1")
|
DEFAULT_PROVIDER, BASE_URLS.get("openai", "https://api.openai.com/v1")
|
||||||
|
|
|
||||||
|
|
@ -28,14 +28,46 @@ _PASSTHROUGH_KWARGS = (
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json"
|
||||||
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
|
||||||
CONFIG = json.load(config_file)
|
|
||||||
|
def _load_config() -> dict:
|
||||||
|
try:
|
||||||
|
with CONFIG_PATH.open("r", encoding="utf-8") as config_file:
|
||||||
|
config = json.load(config_file)
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc
|
||||||
|
except OSError as exc:
|
||||||
|
raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc
|
||||||
|
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}")
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _get_base_urls(config: dict) -> dict[str, str]:
|
||||||
|
base_urls = config.get("BASE_URLS")
|
||||||
|
if not isinstance(base_urls, list):
|
||||||
|
raise RuntimeError(f"Invalid or missing 'BASE_URLS' in config file: {CONFIG_PATH}")
|
||||||
|
|
||||||
|
mapped_urls: dict[str, str] = {}
|
||||||
|
for item in base_urls:
|
||||||
|
if (
|
||||||
|
isinstance(item, list)
|
||||||
|
and len(item) == 2
|
||||||
|
and isinstance(item[0], str)
|
||||||
|
and isinstance(item[1], str)
|
||||||
|
):
|
||||||
|
mapped_urls[item[0].lower()] = item[1]
|
||||||
|
return mapped_urls
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG = _load_config()
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
_BASE_URLS = {
|
_BASE_URLS = _get_base_urls(CONFIG)
|
||||||
display.lower(): url for display, url in CONFIG.get("BASE_URLS", [])
|
|
||||||
}
|
|
||||||
_PROVIDER_BASE_URLS = {
|
_PROVIDER_BASE_URLS = {
|
||||||
"xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"),
|
"xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"),
|
||||||
"openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"),
|
"openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"),
|
||||||
|
|
@ -68,7 +100,7 @@ class OpenAIClient(BaseLLMClient):
|
||||||
|
|
||||||
def get_llm(self) -> Any:
|
def get_llm(self) -> Any:
|
||||||
"""Return configured ChatOpenAI instance."""
|
"""Return configured ChatOpenAI instance."""
|
||||||
llm_kwargs = {"model": self.model}
|
llm_kwargs: dict[str, Any] = {"model": self.model}
|
||||||
|
|
||||||
# Provider-specific base URL and auth
|
# Provider-specific base URL and auth
|
||||||
if self.provider in _PROVIDER_BASE_URLS:
|
if self.provider in _PROVIDER_BASE_URLS:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue