diff --git a/cli/utils.py b/cli/utils.py index 294a8ba6..9b54ffba 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import List, Optional, Tuple, Dict +from typing import List import questionary from rich.console import Console @@ -10,17 +10,50 @@ from cli.models import AnalystType console = Console() CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json" -with CONFIG_PATH.open("r", encoding="utf-8") as config_file: - CONFIG = json.load(config_file) -BASE_URLS = [(display, url) for display, url in CONFIG["BASE_URLS"]] +def _exit_with_config_error(message: str) -> None: + console.print(f"\n[red]{message}[/red]") + exit(1) + + +def _load_config() -> dict: + try: + with CONFIG_PATH.open("r", encoding="utf-8") as config_file: + return json.load(config_file) + except FileNotFoundError: + _exit_with_config_error(f"Config file not found: {CONFIG_PATH}") + except json.JSONDecodeError as exc: + _exit_with_config_error(f"Invalid JSON in config file: {exc}") + except OSError as exc: + _exit_with_config_error(f"Unable to read config file: {exc}") + return {} + + +def _get_config_section(config: dict, key: str, expected_type: type): + value = config.get(key) + if not isinstance(value, expected_type): + _exit_with_config_error( + f"Invalid or missing '{key}' in config file: {CONFIG_PATH}" + ) + return value + + +CONFIG = _load_config() +BASE_URLS = [ + (display, url) + for display, url in _get_config_section(CONFIG, "BASE_URLS", list) +] DEEP_AGENT_OPTIONS = { - provider: [(display, value) for display, value in options] - for provider, options in CONFIG["DEEP_AGENT_OPTIONS"].items() + provider: [(display, model) for display, model in options] + for provider, options in _get_config_section( + CONFIG, "DEEP_AGENT_OPTIONS", dict + ).items() } SHALLOW_AGENT_OPTIONS = { - provider: [(display, value) for display, value in options] - for provider, options in CONFIG["SHALLOW_AGENT_OPTIONS"].items() + provider: [(display, model) for display, model in options] + for provider, options in _get_config_section( + CONFIG, "SHALLOW_AGENT_OPTIONS", dict + ).items() } TICKER_INPUT_EXAMPLES = "Examples: SPY, CNC.TO, 7203.T, 0700.HK" diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index c5a9e348..c1778a39 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -3,11 +3,37 @@ import os from pathlib import Path CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.json" -with CONFIG_PATH.open("r", encoding="utf-8") as config_file: - CONFIG = json.load(config_file) -DEFAULT_LLM_SETTINGS = CONFIG.get("DEFAULT_LLM_SETTINGS", {}) -BASE_URLS = {display.lower(): url for display, url in CONFIG.get("BASE_URLS", [])} + +def _load_config() -> dict: + try: + with CONFIG_PATH.open("r", encoding="utf-8") as config_file: + config = json.load(config_file) + except FileNotFoundError as exc: + raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc + except json.JSONDecodeError as exc: + raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc + except OSError as exc: + raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc + + if not isinstance(config, dict): + raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}") + return config + + +def _get_section(config: dict, key: str, expected_type: type): + section = config.get(key) + if not isinstance(section, expected_type): + raise RuntimeError(f"Invalid or missing '{key}' in config file: {CONFIG_PATH}") + return section + + +CONFIG = _load_config() +DEFAULT_LLM_SETTINGS = _get_section(CONFIG, "DEFAULT_LLM_SETTINGS", dict) +BASE_URLS = { + display.lower(): url + for display, url in _get_section(CONFIG, "BASE_URLS", list) +} DEFAULT_PROVIDER = DEFAULT_LLM_SETTINGS.get("llm_provider", "openai").lower() DEFAULT_BACKEND_URL = BASE_URLS.get( DEFAULT_PROVIDER, BASE_URLS.get("openai", "https://api.openai.com/v1") diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 0a47ec3e..3d4ae1a5 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -28,14 +28,46 @@ _PASSTHROUGH_KWARGS = ( ) CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json" -with CONFIG_PATH.open("r", encoding="utf-8") as config_file: - CONFIG = json.load(config_file) + + +def _load_config() -> dict: + try: + with CONFIG_PATH.open("r", encoding="utf-8") as config_file: + config = json.load(config_file) + except FileNotFoundError as exc: + raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc + except json.JSONDecodeError as exc: + raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc + except OSError as exc: + raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc + + if not isinstance(config, dict): + raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}") + return config + + +def _get_base_urls(config: dict) -> dict[str, str]: + base_urls = config.get("BASE_URLS") + if not isinstance(base_urls, list): + raise RuntimeError(f"Invalid or missing 'BASE_URLS' in config file: {CONFIG_PATH}") + + mapped_urls: dict[str, str] = {} + for item in base_urls: + if ( + isinstance(item, list) + and len(item) == 2 + and isinstance(item[0], str) + and isinstance(item[1], str) + ): + mapped_urls[item[0].lower()] = item[1] + return mapped_urls + + +CONFIG = _load_config() load_dotenv() -_BASE_URLS = { - display.lower(): url for display, url in CONFIG.get("BASE_URLS", []) -} +_BASE_URLS = _get_base_urls(CONFIG) _PROVIDER_BASE_URLS = { "xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"), "openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"), @@ -68,7 +100,7 @@ class OpenAIClient(BaseLLMClient): def get_llm(self) -> Any: """Return configured ChatOpenAI instance.""" - llm_kwargs = {"model": self.model} + llm_kwargs: dict[str, Any] = {"model": self.model} # Provider-specific base URL and auth if self.provider in _PROVIDER_BASE_URLS: