Source code for minerva.utils.instantiators

from typing import Optional, Union, Type, TypeVar
from jsonargparse import ArgumentParser
import yaml
import json
from minerva.utils.typing import PathLike
from pathlib import Path


[docs] class ParserException(Exception): """Custom exception for parser errors.""" pass
T = TypeVar("T")
[docs] def _instantiate_cls( cls: Type[T], config_dict: dict, additional_kwargs: Optional[dict] = None, ) -> T: """Instantiate a class from a configuration dictionary. This function uses the `ArgumentParser` from `jsonargparse` to parse the configuration dictionary and instantiate the class with the provided arguments. It also allows for additional keyword arguments to be passed. Parameters ---------- cls : Type[T] The class to instantiate. config_dict : dict The configuration dictionary containing the parameters for the class, following the structure expected by `jsonargparse`. additional_kwargs : Optional[dict], optional Additional arguments that override or extend the configuration dictionary. It uses dot notation for nested parameters, e.g., `{"nested.param": "value"}`. Returns ------- T An instance of the class `cls` initialized with the parameters from `config_dict` and `additional_kwargs`. Raises ------ ParserException If there is an error during parsing or instantiation, a `ParserException` is raised. """ arg_name = "value" parser = ArgumentParser() parser.add_argument(arg_name, type=cls, help=f"Instantiator for {cls.__name__}") args = [str(config_dict)] if additional_kwargs is not None: for k, v in additional_kwargs.items(): if not k.startswith("--"): k = f"--{arg_name}.{k}" args.append(f"{k}={v}") try: parsed_config = parser.parse_args(args) instantiated_cls = parser.instantiate_classes(parsed_config) except SystemExit as e: raise ParserException(f"Error instantiating class {cls.__name__}") return instantiated_cls.get(arg_name)
[docs] def instantiate_cls( cls: Type[T], config: Union[dict, PathLike], additional_kwargs: Optional[dict] = None, ) -> T: """Instantiate a class from a configuration dictionary. This function uses the `ArgumentParser` from `jsonargparse` to parse the configuration dictionary and instantiate the class with the provided arguments. It also allows for additional keyword arguments to be passed. Parameters ---------- cls : Type[T] The class to instantiate. config : Union[dict, PathLike] The configuration dictionary containing the parameters for the class, following the structure expected by `jsonargparse`. It can also be a path to a JSON or YAML file containing the configuration. If a path is provided, it will read the file and parse its contents into a dictionary. Supported formats are JSON (`.json`) and YAML (`.yaml`, `.yml`). additional_kwargs : Optional[dict], optional Additional arguments that override or extend the configuration dictionary. It uses dot notation for nested parameters, e.g., `{"nested.param": "value"}`. Returns ------- T An instance of the class `cls` initialized with the parameters from `config_dict` and `additional_kwargs`. Raises ------ ParserException If there is an error during parsing or instantiation, a `ParserException` is raised. FileNotFoundError If the provided path does not exist when a path-like object is given. ValueError If the provided configuration is neither a dictionary nor a valid path-like object, or if the file format is unsupported. Examples -------- >>> from minerva.utils.instantiators import instantiate_cls >>> from minerva.models.nets.base import SimpleSupervisedModel >>> model_config = { "class_path": "minerva.models.nets.base.SimpleSupervisedModel", "init_args": { "backbone": { "class_path": "minerva.models.nets.time_series.cnns.CNN_PF_Backbone", "init_args": {"include_middle": True}, }, "fc": { "class_path": "minerva.models.nets.mlp.MLP", "init_args": {"layer_sizes": [768, 128, 6]}, }, "loss_fn": {"class_path": "torch.nn.CrossEntropyLoss"}, "flatten": True, }, } >>> model = instantiate_cls(SimpleSupervisedModel, model_config) """ if isinstance(config, dict): config_dict = config elif isinstance(config, PathLike): path = Path(config) if not path.exists(): raise FileNotFoundError(f"Config file {path} does not exist.") if path.suffix == ".json": with open(path, "r") as f: config_dict = json.load(f) elif path.suffix in [".yaml", ".yml"]: with open(path, "r") as f: config_dict = yaml.safe_load(f) else: raise ValueError(f"Unsupported config file format: {path.suffix}") else: raise ValueError(f"Unsupported config type: {type(config)}.") return _instantiate_cls( cls=cls, config_dict=config_dict, additional_kwargs=additional_kwargs, )