# Source code for pimmslearn.models

import json
import logging
import pickle
import pprint
from functools import reduce
from operator import mul
from pathlib import Path
from typing import Callable, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as sklm
import torch
from fastai import learner
from fastcore.foundation import L

import pimmslearn
from pimmslearn.models import ae, analysis, collab, vae

# Module-level logger; used below when collecting and computing metrics.
logger = logging.getLogger(__name__)

# NumPy scalar one: the default divisor (i.e. no normalisation) for the
# training and validation loss curves drawn by ``plot_loss``.
NUMPY_ONE = np.int64(1)

# Public API: the model submodules imported above plus the helpers defined in
# this module. Kept as a plain literal list so static tools can read it.
__all__ = [
    # model submodules
    "ae",
    "analysis",
    "collab",
    "vae",
    # plotting / training helpers
    "plot_loss",
    "plot_training_losses",
    "calc_net_weight_count",
    "RecorderDump",
    # evaluation helpers ("calculte_metrics" is the historical public spelling)
    "split_prediction_by_mask",
    "compare_indices",
    "collect_metrics",
    "calculte_metrics",
    "Metrics",
    "get_df_from_nested_dict",
]


def plot_loss(
    recorder: learner.Recorder,
    norm_train: np.int64 = NUMPY_ONE,
    norm_val: np.int64 = NUMPY_ONE,
    skip_start: int = 5,
    with_valid: bool = True,
    ax: plt.Axes = None,
) -> plt.Axes:
    """Draw the recorded training (and validation) loss onto a matplotlib Axes.

    Variant of fastai's ``Recorder.plot_loss`` that accepts the target Axes,
    so several loss curves can be combined in one figure. It can also rescale
    the curves by constant divisors.

    Parameters
    ----------
    recorder : learner.Recorder
        Object exposing ``losses``, ``iters`` and ``values``, e.g.
        ``learn.recorder`` (a ``RecorderDump`` works as well).
    norm_train : np.int64, optional
        Divisor applied to the training losses, by default 1.
    norm_val : np.int64, optional
        Divisor applied to the validation losses, by default 1.
    skip_start : int, optional
        Number of leading training batches left out. Validation points
        recorded before that batch are left out too. By default 5.
    with_valid : bool, optional
        Whether to draw the validation curve, by default True.
    ax : plt.Axes, optional
        Axes to draw on. A new figure and Axes are created when not given.

    Returns
    -------
    plt.Axes
        The Axes holding the "train" (and "valid") lines plus a legend.
    """
    if not ax:
        _, ax = plt.subplots()

    train_steps = list(range(skip_start, len(recorder.losses)))
    train_losses = recorder.losses[skip_start:] / norm_train
    ax.plot(train_steps, train_losses, label="train")

    if with_valid:
        # count of validation records logged before batch ``skip_start``
        n_skipped = (np.array(recorder.iters) < skip_start).sum()
        # entry 1 of each record: presumably the validation loss in fastai's
        # [train_loss, valid_loss, *metrics] layout -- verify for custom setups
        valid_losses = L(recorder.values[n_skipped:]).itemgot(1) / norm_val
        ax.plot(recorder.iters[n_skipped:], valid_losses, label="valid")

    ax.legend()
    return ax
# Default (train, valid) loss divisors for ``plot_training_losses``: no scaling.
NORM_ONES = np.array([1, 1], dtype="int")


def plot_training_losses(
    learner: learner.Learner,
    name: str,
    ax=None,
    norm_factors=NORM_ONES,
    folder=None,
    figsize=(15, 8),
):
    """Plot the loss history of a fitted learner and optionally save it.

    Parameters
    ----------
    learner : learner.Learner
        Fitted learner. Its ``recorder`` provides the loss history.
    name : str
        Model name. Used in the title and, lower-cased, in saved file names.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure of size ``figsize`` is created.
    norm_factors : sequence of exactly two values, optional
        ``(norm_train, norm_val)`` divisors for the training and validation
        losses. A ``None`` for ``norm_val`` omits the validation curve.
        By default ``[1, 1]`` (no scaling).
    folder : str or Path, optional
        If given, the recorder history is pickled there via ``RecorderDump``.
        The figure is saved with ``pimmslearn.savefig`` as ``<name>_training``.
    figsize : tuple, optional
        Size of a newly created figure, by default ``(15, 8)``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing ``ax``.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()
    ax.set_title(f"{name} loss")
    norm_train, norm_val = norm_factors  # exactly two
    # A missing validation divisor signals that no validation curve is wanted.
    with_valid = norm_val is not None
    # Use this module's ``plot_loss`` explicitly: it is the variant that
    # accepts ``norm_train``/``norm_val``. fastai's stock Recorder.plot_loss
    # does not take these keyword arguments.
    plot_loss(
        learner.recorder,
        skip_start=5,
        ax=ax,
        with_valid=with_valid,
        norm_train=norm_train,
        norm_val=norm_val,
    )
    if folder is not None:
        name = name.lower()
        _ = RecorderDump(learner.recorder, name).save(folder)
        pimmslearn.savefig(fig, name=f"{name}_training", folder=folder)
    return fig
def calc_net_weight_count(model: torch.nn.modules.module.Module) -> int:
    """Count the trainable scalar parameters of a torch module.

    Only parameters with ``requires_grad=True`` are counted.

    Parameters
    ----------
    model : torch.nn.Module
        Network to inspect. Its train/eval mode is left untouched.

    Returns
    -------
    int
        Total number of elements across all trainable parameter tensors
        (0 for a module without trainable parameters).
    """
    # ``requires_grad`` does not depend on train/eval mode, so there is no
    # reason to switch the model into training mode just to count weights.
    return int(
        sum(param.numel() for param in model.parameters() if param.requires_grad)
    )
class RecorderDump:
    """Picklable snapshot of a fastai ``Recorder``'s loss history.

    Holds the recorder's ``losses``, ``values`` and ``iters`` so they can be
    written to and read back from disk with :mod:`pickle`. A dump can be
    re-plotted through its ``plot_loss`` attribute.
    """

    filename_tmp = "recorder_{}.pkl"

    def __init__(self, recorder, name):
        # plain attribute references to the recorder's history (no copies)
        self.losses = recorder.losses
        self.values = recorder.values
        self.iters = recorder.iters
        self.name = name

    def save(self, folder="."):
        """Pickle this dump to ``<folder>/recorder_<name>.pkl``."""
        target = Path(folder).joinpath(self.filename_tmp.format(self.name))
        with target.open("wb") as handle:
            pickle.dump(self, handle)

    @classmethod
    def load(cls, filepath, name):
        """Unpickle ``recorder_<name>.pkl`` from the directory ``filepath``.

        Uses ``pickle.load``, so only load files from trusted sources.
        """
        source = Path(filepath).joinpath(cls.filename_tmp.format(name))
        with source.open("rb") as handle:
            return pickle.load(handle)

    # Reuse the module-level plot_loss as a method: a dump exposes the same
    # ``losses``/``values``/``iters`` attributes that plot_loss reads.
    plot_loss = plot_loss
def split_prediction_by_mask(
    pred: pd.DataFrame, mask: pd.DataFrame, check_keeps_all: bool = False
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split a prediction table into entries outside and inside a boolean mask.

    Parameters
    ----------
    pred : pd.DataFrame
        Prediction DataFrame.
    mask : pd.DataFrame
        Boolean mask with the same index and columns as ``pred``.
    check_keeps_all : bool, optional
        If True, check that every entry of ``pred`` ended up in one of the
        two outputs, by default False.

    Returns
    -------
    Tuple
        ``(pred where mask is False, pred where mask is True)``. Each is stacked
        to long format (a Series indexed by (row, column) for flat columns)
        with missing values removed.

    Raises
    ------
    AssertionError
        If ``check_keeps_all`` is set and entries were lost, e.g. because
        ``pred`` itself contains missing values.
    """
    # ``pred[mask]`` keeps values where mask is True and sets the rest to NaN.
    # The explicit ``dropna(how="all")`` reproduces legacy ``stack()`` dropping,
    # which pandas' new stack implementation no longer does by itself.
    test_pred_observed = pred[~mask].stack().dropna(how="all")
    test_pred_real_na = pred[mask].stack().dropna(how="all")
    if check_keeps_all:
        n_expected = reduce(mul, pred.shape)
        n_found = len(test_pred_real_na) + len(test_pred_observed)
        # explicit raise instead of ``assert`` so the check survives ``python -O``
        if n_found != n_expected:
            raise AssertionError(
                f"Split kept {n_found} of {n_expected} entries "
                "(does pred contain missing values?)"
            )
    return test_pred_observed, test_pred_real_na
def compare_indices(first_index: pd.Index, second_index: pd.Index) -> pd.Index:
    """Print and return the entries of ``first_index`` missing from ``second_index``.

    This is the set difference of two Index objects. ``first_index`` should be
    the larger collection. Entries only present in ``second_index`` are NOT
    reported (pandas' one-sided ``difference``).

    Parameters
    ----------
    first_index : pd.Index
        Index, should be the superset.
    second_index : pd.Index
        Index, should be the subset.

    Returns
    -------
    pd.Index
        New Index with the elements of ``first_index`` not in ``second_index``.
    """
    _diff_index = first_index.difference(second_index)
    if len(_diff_index):
        print(
            "Some predictions couldn't be generated using the approach using artifical replicates.\n"
            "These will be omitted for evaluation."
        )
        for _index in _diff_index:
            # MultiIndex entries are tuples: print one padded column per level.
            # Flat-index labels are printed whole (indexing into them would
            # split strings into characters or fail for numbers).
            levels = _index if isinstance(_index, tuple) else (_index,)
            print("\t ".join(f"{str(level):<40}" for level in levels))
    return _diff_index
# Default scorers for ``calculte_metrics``: (metric name, sklearn metric) pairs.
# Each function is called with ``y_true=`` and ``y_pred=`` keyword arguments.
scoring = [("MSE", sklm.mean_squared_error), ("MAE", sklm.mean_absolute_error)]
def collect_metrics(metrics_jsons: List, key_fct: Callable) -> dict:
    """Collect and aggregate a bunch of json metrics.

    Each file is loaded, flattened with
    ``pimmslearn.pandas.flatten_dict_of_dicts`` and merged under the key
    returned by ``key_fct``. Files mapping to the same key are merged entry
    by entry.

    Parameters
    ----------
    metrics_jsons : List
        List of filepaths to json metric files (read as UTF-8).
    key_fct : Callable
        Callable which creates the outer key from a single filepath.

    Returns
    -------
    dict
        Aggregated metrics dictionary with outer key defined by ``key_fct``.

    Raises
    ------
    AssertionError
        If a key should be overwritten, but its value would change.
    """
    all_metrics = {}
    for fname in metrics_jsons:
        fname = Path(fname)
        logger.info(f"Load file: {fname = }")
        key = key_fct(fname)  # level, repeat
        logger.debug(f"{key = }")
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(fname, encoding="utf-8") as f:
            loaded = json.load(f)
        loaded = pimmslearn.pandas.flatten_dict_of_dicts(loaded)
        if key not in all_metrics:
            all_metrics[key] = loaded
            continue
        merged = all_metrics[key]
        for k, v in loaded.items():
            if k not in merged:
                merged[k] = v
                continue
            logger.debug(f"Found existing key: {k = } ")
            # explicit raise (not ``assert``) so the check survives ``python -O``
            if merged[k] != v:
                raise AssertionError(
                    "Diverging values for {k}: {v1} vs {v2}".format(
                        k=k, v1=merged[k], v2=v
                    )
                )
    return all_metrics
def calculte_metrics(
    pred_df: pd.DataFrame,
    true_col: Union[int, str, None] = None,
    scoring: List[Tuple[str, Callable]] = scoring,
) -> dict:
    """Score every prediction column of ``pred_df`` against a ground-truth column.

    Parameters
    ----------
    pred_df : pd.DataFrame
        Ground-truth column plus one column of predictions per model.
    true_col : int or str, optional
        Position (int, including numpy integers) or label (str) of the
        ground-truth column. If not given (or falsy), the first column is used
        as ground truth and all remaining columns are treated as predictions.
    scoring : List[Tuple[str, Callable]], optional
        ``(name, function)`` pairs. Each function is called as
        ``f(y_true=..., y_pred=...)`` like the sklearn metrics. By default the
        module-level ``scoring`` list.

    Returns
    -------
    dict
        ``{model_column: {metric_name: float, ..., "N": int, "prop": float}}``.
        ``N`` is the number of non-missing predictions scored. ``prop`` is the
        fraction of that column's rows that were non-missing.

    Raises
    ------
    ValueError
        If ``true_col`` is neither int nor str, or if the ground-truth column
        contains missing values.
    """
    if not true_col:
        # assume first column is truth if None is given
        y_true = pred_df.iloc[:, 0]
        print(f"Selected as truth to compare to: {y_true.name}")
        y_pred = pred_df.iloc[:, 1:]
    elif isinstance(true_col, (int, np.integer)):
        y_true = pred_df.iloc[:, true_col]
        y_pred = pred_df.drop(y_true.name, axis=1)
    elif isinstance(true_col, str):
        y_true = pred_df[true_col]
        y_pred = pred_df.drop(true_col, axis=1)
    else:
        raise ValueError(
            f"true_col has to be of type str or int, not {type(true_col)}"
        )

    if y_true.isna().any():
        raise ValueError(
            f"Ground truth column '{y_true.name}' contains missing values. "
            "Drop these rows first."
        )

    metrics = {}
    for model_key in y_pred:
        model_pred = y_pred[model_key]
        model_pred_no_na = model_pred.dropna()
        if len(model_pred) > len(model_pred_no_na):
            # list(...) works for flat indices and MultiIndexes of any depth.
            logger.info(
                "Drop indices for %s: %s",
                model_key,
                list(model_pred.index.difference(model_pred_no_na.index)),
            )
        # compare only on rows where this model has a prediction
        y_true_model = y_true.loc[model_pred_no_na.index]
        metrics[model_key] = {
            metric_name: float(fct(y_true=y_true_model, y_pred=model_pred_no_na))
            for metric_name, fct in scoring
        }
        metrics[model_key]["N"] = int(len(model_pred_no_na))
        metrics[model_key]["prop"] = len(model_pred_no_na) / len(model_pred)
    return metrics
class Metrics:
    """Registry of metric dictionaries, one entry per user-chosen key."""

    def __init__(self):
        # key -> result of ``calculte_metrics`` for that key
        self.metrics = {}

    def add_metrics(self, pred, key):
        """Score ``pred`` and store the result under ``key``.

        Rows of ``pred`` with any missing value are dropped first, so all
        models are compared on the same rows. The first column is used as
        ground truth.

        Returns
        -------
        dict
            The stored metrics for ``key``.
        """
        result = calculte_metrics(pred_df=pred.dropna())
        self.metrics[key] = result
        return result

    def __repr__(self):
        return pprint.pformat(self.metrics, compact=True, indent=2)
def get_df_from_nested_dict(
    nested_dict, column_levels=("data_split", "model", "metric_name"), row_name="subset"
):
    """Build a DataFrame with one row per outer key of ``nested_dict``.

    Each value is flattened with ``pimmslearn.pandas.flatten_dict_of_dicts``.
    The flattened keys become the columns, whose level names are set to
    ``column_levels``. This presumes tuple keys with one entry per level;
    verify against the flattening helper. The outer keys form the index,
    named ``row_name``.
    """
    flattened = {
        row_key: pimmslearn.pandas.flatten_dict_of_dicts(run_metrics)
        for row_key, run_metrics in nested_dict.items()
    }
    frame = pd.DataFrame.from_dict(flattened, orient="index")
    frame.columns.names = column_levels
    frame.index.name = row_name
    return frame