From d1818de073b5e9d2d42fc7fcf44341e9fa8f42fd Mon Sep 17 00:00:00 2001 From: Maytekin Date: Tue, 24 Mar 2026 23:46:20 +0000 Subject: [PATCH] refactor(llm_clients): extract config loading logic into separate module Move config loading and validation functions from openai_client.py and factory.py into a new shared config_loader.py module. This centralizes configuration handling, reduces code duplication, and improves maintainability. The factory now gracefully falls back to default provider types if config loading fails. --- tradingagents/llm_clients/config_loader.py | 41 +++++++++++++++++ tradingagents/llm_clients/factory.py | 51 +++++++++++----------- tradingagents/llm_clients/openai_client.py | 43 ++---------------- 3 files changed, 69 insertions(+), 66 deletions(-) create mode 100644 tradingagents/llm_clients/config_loader.py diff --git a/tradingagents/llm_clients/config_loader.py b/tradingagents/llm_clients/config_loader.py new file mode 100644 index 00000000..78fe684b --- /dev/null +++ b/tradingagents/llm_clients/config_loader.py @@ -0,0 +1,41 @@ +import json +from pathlib import Path +from typing import Any + +CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json" + + +def load_config() -> dict: + try: + with CONFIG_PATH.open("r", encoding="utf-8") as config_file: + config = json.load(config_file) + except FileNotFoundError as exc: + raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc + except json.JSONDecodeError as exc: + raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc + except OSError as exc: + raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc + if not isinstance(config, dict): + raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}") + return config + + +def get_config_section(config: dict, key: str, expected_type: type) -> Any: + value = config.get(key) + if not isinstance(value, expected_type): + raise RuntimeError(f"Invalid or missing '{key}' in config file: {CONFIG_PATH}") + return value + + +def get_base_urls_map(config: dict) -> dict[str, str]: + base_urls = get_config_section(config, "BASE_URLS", list) + mapped_urls: dict[str, str] = {} + for item in base_urls: + if ( + isinstance(item, list) + and len(item) == 2 + and isinstance(item[0], str) + and isinstance(item[1], str) + ): + mapped_urls[item[0].lower()] = item[1] + return mapped_urls diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index fda6f024..7574f4b1 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -1,40 +1,39 @@ -import json -from pathlib import Path +import warnings from typing import Optional from .base_client import BaseLLMClient from .openai_client import OpenAIClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient +from .config_loader import load_config, get_config_section -CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json" - - -def _load_config() -> dict: - try: - with CONFIG_PATH.open("r", encoding="utf-8") as config_file: - config = json.load(config_file) - except FileNotFoundError as exc: - raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc - except json.JSONDecodeError as exc: - raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc - except OSError as exc: - raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc - if not isinstance(config, dict): - raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}") - return config +DEFAULT_PROVIDER_TYPES: dict[str, str] = { + "openai": "openai", + "anthropic": "anthropic", + "google": "google", + "xai": "openai", + "openrouter": "openai", + "ollama": "openai", + "lmstudio": "openai", +} def _load_provider_types() -> dict[str, str]: - provider_types = _load_config().get("LLM_PROVIDER_TYPES") - if not isinstance(provider_types, dict): - raise RuntimeError( - f"Invalid or missing 'LLM_PROVIDER_TYPES' in config file: {CONFIG_PATH}" + try: + config = load_config() + provider_types = get_config_section(config, "LLM_PROVIDER_TYPES", dict) + return { + str(name).lower(): str(client_type).lower() + for name, client_type in provider_types.items() + } + except RuntimeError as exc: + warnings.warn( + f"Failed to load LLM_PROVIDER_TYPES from config.json: {exc}. " + "Using built-in provider mapping.", + RuntimeWarning, + stacklevel=2, ) - return { - str(name).lower(): str(client_type).lower() - for name, client_type in provider_types.items() - } + return DEFAULT_PROVIDER_TYPES.copy() _PROVIDER_TYPES: dict[str, str] | None = None diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 3d4ae1a5..fc17a9ae 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -1,12 +1,11 @@ -import json import os -from pathlib import Path from typing import Any, Optional from dotenv import load_dotenv from langchain_openai import ChatOpenAI from .base_client import BaseLLMClient, normalize_content +from .config_loader import load_config, get_base_urls_map from .validators import validate_model @@ -27,47 +26,11 @@ _PASSTHROUGH_KWARGS = ( "api_key", "callbacks", "http_client", "http_async_client", ) -CONFIG_PATH = Path(__file__).resolve().parents[2] / "config.json" - - -def _load_config() -> dict: - try: - with CONFIG_PATH.open("r", encoding="utf-8") as config_file: - config = json.load(config_file) - except FileNotFoundError as exc: - raise RuntimeError(f"Config file not found: {CONFIG_PATH}") from exc - except json.JSONDecodeError as exc: - raise RuntimeError(f"Invalid JSON in config file: {CONFIG_PATH}") from exc - except OSError as exc: - raise RuntimeError(f"Unable to read config file: {CONFIG_PATH}") from exc - - if not isinstance(config, dict): - raise RuntimeError(f"Invalid config format in file: {CONFIG_PATH}") - return config - - -def _get_base_urls(config: dict) -> dict[str, str]: - base_urls = config.get("BASE_URLS") - if not isinstance(base_urls, list): - raise RuntimeError(f"Invalid or missing 'BASE_URLS' in config file: {CONFIG_PATH}") - - mapped_urls: dict[str, str] = {} - for item in base_urls: - if ( - isinstance(item, list) - and len(item) == 2 - and isinstance(item[0], str) - and isinstance(item[1], str) - ): - mapped_urls[item[0].lower()] = item[1] - return mapped_urls - - -CONFIG = _load_config() +CONFIG = load_config() load_dotenv() -_BASE_URLS = _get_base_urls(CONFIG) +_BASE_URLS = get_base_urls_map(CONFIG) _PROVIDER_BASE_URLS = { "xai": _BASE_URLS.get("xai", "https://api.x.ai/v1"), "openrouter": _BASE_URLS.get("openrouter", "https://openrouter.ai/api/v1"),