# Source code for pimmslearn.nb

import logging
from pathlib import Path
from pprint import pformat

import yaml

import pimmslearn.io

# Module-level logger. No name is passed, so this is the *root* logger
# (``logging.getLogger()`` with ``name=None`` returns the root logger).
logger = logging.getLogger(name=None)


class Config:
    """Config class with a setter enforcing that config entries cannot be overwritten.

    Assigning an attribute that already exists with a *different* value raises
    ``AttributeError``; re-assigning an equal value is allowed. Use
    :meth:`overwrite_entry` to replace a value deliberately.

    Entries may themselves be ``Config`` instances (e.g. to group keys or
    paths).
    """

    def __setattr__(self, entry, value):
        """Set ``entry`` unless it already exists with a different value.

        Raises
        ------
        AttributeError
            If ``entry`` is already set to a value that compares unequal.
        """
        # hasattr also sees class-level names (e.g. the methods below), so an
        # entry cannot silently shadow a method such as ``keys`` or ``dump``.
        if hasattr(self, entry) and getattr(self, entry) != value:
            raise AttributeError(f"{entry} already set to {getattr(self, entry)}")
        super().__setattr__(entry, value)

    def __repr__(self):
        return pformat(vars(self))  # does not work in Jupyter?

    def overwrite_entry(self, entry, value):
        """Explicitly overwrite a given value (bypasses the overwrite check)."""
        super().__setattr__(entry, value)

    def dump(self, fname=None):
        """Write the config entries to a YAML file.

        Parameters
        ----------
        fname : str or Path, optional
            Target file. Defaults to ``<out_folder>/model_config.yml``.

        Raises
        ------
        AttributeError
            If ``fname`` is not given and ``out_folder`` is not set.
        """
        if fname is None:
            # Only the attribute lookup can raise AttributeError here.
            try:
                out_folder = self.out_folder
            except AttributeError as e:
                raise AttributeError(
                    'Specify fname or set "out_folder" attribute.'
                ) from e
            fname = Path(out_folder) / "model_config.yml"
        d = pimmslearn.io.parse_dict(input_dict=self.__dict__)
        with open(fname, "w", encoding="utf-8") as f:
            yaml.dump(d, f)
        logger.info(f"Dumped config to: {fname}")

    @classmethod
    def from_dict(cls, d: dict):
        """Create a new config with one attribute per item of ``d``."""
        cfg = cls()
        for k, v in d.items():
            setattr(cfg, k, v)
        return cfg

    def update_from_dict(self, params: dict):
        """Set entries from ``params``; conflicting entries are skipped and logged.

        An entry conflicts when it is already set to a different value, in
        which case the current value is kept.
        """
        for k, v in params.items():
            try:
                setattr(self, k, v)
            except AttributeError:
                # Report the value that is actually kept, plus the rejected one.
                logger.info(
                    f"Already set attribute: {k} has value {getattr(self, k)}"
                    f" (ignored new value {v})"
                )

    def keys(self):
        """Return the names of the instance entries."""
        return vars(self).keys()

    def items(self):
        """Return ``(name, value)`` pairs of the instance entries."""
        return vars(self).items()

    def values(self):
        """Return the values of the instance entries."""
        return vars(self).values()
def get_params(args: dict.keys, globals, remove=True) -> dict:
    """Collect public entries of ``globals`` whose names are not in ``args``.

    Names starting with an underscore are skipped.

    Parameters
    ----------
    args : dict.keys
        Names to exclude; any container supporting ``in`` works.
    globals : dict
        Namespace to scan, e.g. the dict returned by ``globals()``.
    remove : bool, default True
        If True, the collected names are deleted from ``globals`` via
        ``remove_keys_from_globals``.

    Returns
    -------
    dict
        Mapping of the collected names to their values.
    """
    params = {
        k: v
        for k, v in globals.items()
        # startswith is safe for an empty name, unlike indexing k[0].
        if k not in args and not k.startswith("_")
    }
    if not remove:
        return params
    remove_keys_from_globals(params.keys(), globals=globals)
    return params
def remove_keys_from_globals(keys: dict.keys, globals: dict):
    """Delete every name in ``keys`` from the ``globals`` mapping.

    Each removal is logged at INFO level; names that are not present are
    logged as a warning and skipped instead of raising.
    """
    for name in keys:
        if name not in globals:
            logger.warning(f"Key not found in globals(): {name}")
            continue
        del globals[name]
        logger.info(f"Removed from global namespace: {name}")
def add_default_paths(cfg: "Config", folder_data="", out_root=None):
    """Add default paths to config and create the corresponding directories.

    Sets ``out_folder``, ``data``, ``out_figures``, ``out_metrics``,
    ``out_models`` and ``out_preds`` on ``cfg``.

    Parameters
    ----------
    cfg : Config
        Config providing ``folder_experiment`` (str or Path).
    folder_data : str or Path, optional
        Existing data directory. If empty, ``<folder_experiment>/data`` is
        used and created.
    out_root : str or Path, optional
        Output folder (created). If falsy, ``folder_experiment`` is used.

    Returns
    -------
    Config
        The same ``cfg`` instance, updated in place.

    Raises
    ------
    AssertionError
        If the data directory does not exist.
    """
    # Accept a str as well as a Path for the experiment folder.
    folder_experiment = Path(cfg.folder_experiment)
    if out_root:
        cfg.out_folder = Path(out_root)
        cfg.out_folder.mkdir(exist_ok=True, parents=True)
    else:
        cfg.out_folder = folder_experiment
    if folder_data:
        cfg.data = Path(folder_data)
    else:
        cfg.data = folder_experiment / "data"
        cfg.data.mkdir(exist_ok=True, parents=True)
    # NOTE: assert is skipped under ``python -O``; kept so callers still
    # receive an AssertionError as before.
    assert cfg.data.exists(), f"Directory not found: {cfg.data}"
    # parents=True: folder_experiment may not exist yet when folder_data
    # was supplied (the data/ mkdir above then never ran).
    cfg.out_figures = folder_experiment / "figures"
    cfg.out_figures.mkdir(exist_ok=True, parents=True)
    cfg.out_metrics = folder_experiment
    cfg.out_metrics.mkdir(exist_ok=True, parents=True)
    cfg.out_models = folder_experiment
    cfg.out_models.mkdir(exist_ok=True, parents=True)
    cfg.out_preds = folder_experiment / "preds"
    cfg.out_preds.mkdir(exist_ok=True, parents=True)
    return cfg
def args_from_dict(args: dict) -> "Config":
    """Build a ``Config`` from ``args`` and add the default paths.

    ``args["folder_experiment"]`` is converted to a ``Path`` and the folder is
    created. Optional ``folder_data`` and ``out_root`` entries are forwarded
    to ``add_default_paths``. The input dict is not modified.

    Parameters
    ----------
    args : dict
        Parameters; must contain ``"folder_experiment"``.

    Returns
    -------
    Config
        New config with all entries of ``args`` plus the default paths.

    Raises
    ------
    AssertionError
        If ``"folder_experiment"`` is missing from ``args``.
    """
    assert "folder_experiment" in args, f'Specify "folder_experiment" in {args}.'
    # Work on a copy so the caller's dict is left untouched.
    params = dict(args.items())
    params["folder_experiment"] = Path(params["folder_experiment"])
    cfg = Config.from_dict(params)
    cfg.folder_experiment.mkdir(exist_ok=True, parents=True)
    entries = vars(cfg)
    add_default_paths(
        cfg,
        folder_data=entries.get("folder_data", ""),
        out_root=entries.get("out_root", None),
    )
    return cfg