Source code for pimmslearn.models.collect_dumps

"""Collects metrics and config files from the experiment directory structure."""

import json
import logging
from functools import partial, update_wrapper
from pathlib import Path
from typing import Callable, Iterable, List, Tuple

import yaml

import pimmslearn.pandas

logger = logging.getLogger(__name__)


def select_content(s: str, first_split):
    """Extract the identifying part of a string such as a file stem.

    The text following the first occurrence of ``first_split`` is selected
    (up to the next occurrence of ``first_split``, if there is one). If the
    selected text contains underscores, its last underscore-separated token
    is dropped.

    Parameters
    ----------
    s : str
        String to parse.
    first_split : str
        Marker preceding the content of interest.

    Returns
    -------
    str
        The selected content.

    Raises
    ------
    IndexError
        If ``first_split`` does not occur in ``s``.

    Examples
    --------
    >>> select_content("config_model_1", first_split="config_")
    'model'
    >>> select_content("metrics_VAE", first_split="metrics_")
    'VAE'
    """
    parts = s.split(first_split)
    if len(parts) < 2:
        # Same exception type a bare ``parts[1]`` lookup raises, but with a
        # message naming the marker and the input that lacked it.
        raise IndexError(f"{first_split!r} not found in {s!r}")
    content = parts[1]
    # Drop a trailing "_<token>"; content without an underscore is kept as is.
    head, sep, _ = content.rpartition("_")
    return head if sep else content
def load_config_file(fname: Path, first_split="config_") -> Tuple[str, dict]:
    """Load a YAML configuration dump and derive an identifier for it.

    Parameters
    ----------
    fname : Path
        Path to the YAML file.
    first_split : str, optional
        Marker in the file stem preceding the part used for the identifier,
        by default "config_". Forwarded to ``select_content``.

    Returns
    -------
    Tuple[str, dict]
        The key ``"<parent folder name>_<selected stem content>"`` and the
        parsed file content (expected to be a mapping).
    """
    # Explicit encoding: do not depend on the platform's locale default.
    with open(fname, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    key = f"{fname.parent.name}_{select_content(fname.stem, first_split=first_split)}"
    return key, loaded
def load_metric_file(fname: Path, first_split="metrics_") -> Tuple[str, dict]:
    """Load a JSON metrics dump, flatten it and derive an identifier for it.

    Parameters
    ----------
    fname : Path
        Path to the JSON file.
    first_split : str, optional
        Marker in the file stem preceding the part used for the identifier,
        by default "metrics_". Forwarded to ``select_content``.

    Returns
    -------
    Tuple[str, dict]
        The key ``"<parent folder name>_<selected stem content>"`` and the
        file content after passing it through
        ``pimmslearn.pandas.flatten_dict_of_dicts``.
    """
    # Explicit encoding: do not depend on the platform's locale default.
    with open(fname, encoding="utf-8") as f:
        loaded = json.load(f)
    loaded = pimmslearn.pandas.flatten_dict_of_dicts(loaded)
    key = f"{fname.parent.name}_{select_content(fname.stem, first_split=first_split)}"
    return key, loaded
def collect(
    paths: Iterable,
    load_fn: Callable[[Path], Tuple[str, dict]],
) -> List[dict]:
    """Load every file in ``paths`` and gather the loaded records in a list.

    Parameters
    ----------
    paths : Iterable
        File paths; each one is converted to a ``Path`` before loading.
    load_fn : Callable[[Path], Tuple[str, dict]]
        Loader returning a ``(key, record)`` pair for a single file.

    Returns
    -------
    List[dict]
        One record per path, in iteration order. A record without an
        ``"id"`` entry gets ``"id"`` set to the key returned by ``load_fn``;
        the record object returned by ``load_fn`` is modified in place.
    """
    records = []
    for fname in paths:
        key, loaded = load_fn(Path(fname))
        # Lazy %-formatting; renders the same text as f"{key = }".
        logger.debug("key = %r", key)
        # Keep an "id" provided by the file itself, otherwise use the key.
        loaded.setdefault("id", key)
        records.append(loaded)
    return records
# Pre-configured collectors: ``collect`` with the loader already bound.
# ``update_wrapper`` hands back the partial object it was given after copying
# the metadata of ``collect`` (e.g. ``__name__`` and ``__doc__``) onto it.
collect_metrics = update_wrapper(
    partial(collect, load_fn=load_metric_file),
    collect,
)

collect_configs = update_wrapper(
    partial(collect, load_fn=load_config_file),
    collect,
)