Source code for pimmslearn.plotting.errors
"""Plot errors based on DataFrame with model predictions."""
from __future__ import annotations
from typing import Optional
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from seaborn.categorical import EstimateAggregator
import pimmslearn.pandas.calc_errors
[docs]
def plot_errors_binned(
    pred: pd.DataFrame,
    target_col: str = "observed",
    ax: Optional[Axes] = None,
    palette: Optional[dict] = None,
    metric_name: Optional[str] = None,
    errwidth: float = 1.2,
) -> tuple[Axes, pd.DataFrame]:
    """Plot the error of each model per intensity bin as a grouped bar plot.

    The binned errors are computed by
    `pimmslearn.pandas.calc_errors.calc_errors_per_bin`, reshaped to long
    format and passed to `seaborn.barplot`.

    Parameters
    ----------
    pred: pd.DataFrame
        Predictions, one column per model plus the column `target_col`.
    target_col: str
        Column in `pred` holding the observed values. All other columns are
        treated as models.
    ax: Axes, optional
        Axes to draw on, forwarded to `seaborn.barplot`.
    palette: dict, optional
        Palette forwarded to `seaborn.barplot`.
    metric_name: str, optional
        Name of the value column (and y-axis label). Defaults to
        'Average error'.
    errwidth: float
        Width of the error bar lines, forwarded to `seaborn.barplot`.

    Returns
    -------
    tuple[Axes, pd.DataFrame]
        The axes returned by `seaborn.barplot` and the long-format data that
        was plotted.
    """
    assert (
        target_col in pred.columns
    ), f"Specify `target_col` parameter, `pred` do no contain: {target_col}"
    # every column except the target is a model to be plotted
    models_order = pred.columns.to_list()
    models_order.remove(target_col)

    errors_binned = pimmslearn.pandas.calc_errors.calc_errors_per_bin(
        pred=pred, target_col=target_col
    )

    meta_cols = ["bin", "n_obs"]  # calculated along binned error
    # zero-pad bin labels to a common width so they share the same length
    len_max_bin = len(str(int(errors_binned["bin"].max())))
    n_obs = (
        errors_binned[meta_cols]
        .apply(lambda x: f"{x.bin:0{len_max_bin}}\n(N={x.n_obs:,d})", axis=1)
        .rename("intensity bin")
        .astype("category")
    )

    metric_name = metric_name or "Average error"

    # long format: one row per (observation, model) with its bin label
    # NOTE(review): `hue="model"` below relies on the column index of the
    # frame returned by `calc_errors_per_bin` being named 'model' - confirm.
    errors_binned = (
        errors_binned[models_order]
        .stack()
        .to_frame(metric_name)
        .join(n_obs)
        .reset_index()
    )

    ax = sns.barplot(
        data=errors_binned,
        ax=ax,
        x="intensity bin",
        y=metric_name,
        hue="model",
        palette=palette,
        errwidth=errwidth,
    )
    ax.xaxis.set_tick_params(rotation=90)
    return ax, errors_binned
[docs]
def plot_errors_by_median(
    pred: pd.DataFrame,
    feat_medians: pd.Series,
    target_col: str = "observed",
    ax: Optional[Axes] = None,
    palette: Optional[dict] = None,
    feat_name: Optional[str] = None,
    metric_name: Optional[str] = None,
    errwidth: float = 1.2,
) -> tuple[Axes, pd.DataFrame]:
    """Plot the absolute error of each model, binned by median feature value.

    Bins are the integer part of `feat_medians`. Each x-axis label shows the
    bin and the number of observations of `target_col` falling into it.

    Parameters
    ----------
    pred: pd.DataFrame
        Predictions, one column per model plus the column `target_col`.
    feat_medians: pd.Series
        Median per feature; joined onto `pred` and the errors via the index.
    target_col: str
        Column in `pred` holding the observed values.
    ax: Axes, optional
        Axes to draw on, forwarded to `seaborn.barplot`.
    palette: dict, optional
        Palette forwarded to `seaborn.barplot`.
    feat_name: str, optional
        Feature name used in the x-axis label. Defaults to the index name of
        `feat_medians`, or 'feature' if that is not set.
    metric_name: str, optional
        Name of the value column (and y-axis label). Defaults to
        'Average error'.
    errwidth: float
        Width of the error bar lines, forwarded to `seaborn.barplot`.

    Returns
    -------
    tuple[Axes, pd.DataFrame]
        The axes returned by `seaborn.barplot` and the long-format data that
        was plotted.
    """
    # Resolve the default before the name is used as a column label, so the
    # column created below matches the `y` passed to seaborn.
    metric_name = metric_name or "Average error"

    # calculate absolute errors
    errors = pimmslearn.pandas.get_absolute_error(pred, y_true=target_col)
    errors.columns.name = "model"

    # define bins by integer value of median feature intensity
    feat_medians = feat_medians.astype(int).rename("bin")

    # number of intensities per bin
    n_obs = pred[target_col].to_frame().join(feat_medians)
    n_obs = n_obs.groupby("bin").size().to_frame("n_obs")

    # long format: one row per (observation, model) with its bin and bin size
    errors = errors.stack().to_frame(metric_name).join(feat_medians).reset_index()
    errors = errors.join(n_obs, on="bin")

    if feat_name is None:
        feat_name = feat_medians.index.name
    if not feat_name:
        feat_name = "feature"
    x_axis_name = f"intensity binned by median of {feat_name}"

    # zero-pad bin labels to a common width so they share the same length
    len_max_bin = len(str(int(errors["bin"].max())))
    errors[x_axis_name] = (
        errors[["bin", "n_obs"]]
        .apply(lambda x: f"{x.bin:0{len_max_bin}}\n(N={x.n_obs:,d})", axis=1)
        .astype("category")
    )

    # Keep the returned axes: if `ax` was None, seaborn picks the axes itself
    # and the tick rotation below needs a real Axes object.
    ax = sns.barplot(
        data=errors,
        ax=ax,
        x=x_axis_name,
        y=metric_name,
        hue="model",
        palette=palette,
        errwidth=errwidth,
    )
    ax.xaxis.set_tick_params(rotation=90)
    return ax, errors
[docs]
def get_data_for_errors_by_median(
    errors: pd.DataFrame,
    feat_name: str,
    metric_name: str,
    model_column: str = "model",
    seed: int = 42,
) -> pd.DataFrame:
    """Compute bar heights and confidence intervals per bin and model.

    Intended for seaborn 0.13 and above. The mean of `metric_name` and its
    95% confidence interval (bootstrapped with 1,000 resamples) are computed
    with seaborn's `EstimateAggregator` for every combination of bin and
    model.

    Parameters
    ----------
    errors: pd.DataFrame
        DataFrame created by the `plot_errors_by_median` function.
    feat_name: str
        Feature name that was used to build the bin column
        'intensity binned by median of {feat_name}'.
    metric_name: str
        Column in `errors` holding the error values (MAE, MSE, etc.).
    model_column: str
        Column in `errors` holding the model names.
    seed: int
        Seed passed to the `EstimateAggregator`.

    Returns
    -------
    pd.DataFrame
        One row per (bin, model) with the columns 'bin', `model_column`,
        'mean', 'ci_low' and 'ci_high'.
    """
    bin_column = f"intensity binned by median of {feat_name}"
    group_keys = [bin_column, model_column]
    bootstrap_mean = EstimateAggregator("mean", ("ci", 95), n_boot=1_000, seed=seed)

    def summarize(group: pd.DataFrame):
        # aggregate the metric of a single (bin, model) group
        return bootstrap_mean(group, metric_name)

    # The aggregator works on one group at a time, so the iteration over
    # bins and models is done here with an explicit groupby.
    grouped = errors.groupby(by=group_keys, observed=True)
    summary = grouped[group_keys + [metric_name]].apply(summarize).reset_index()
    summary.columns = ["bin", model_column, "mean", "ci_low", "ci_high"]
    return summary
[docs]
def plot_rolling_error(
    errors: pd.DataFrame,
    metric_name: str,
    window: int = 200,
    min_freq=None,
    freq_col: str = "freq",
    colors_to_use=None,
    ax=None,
) -> Axes:
    """Plot the rolling average of the errors against `freq_col`.

    Parameters
    ----------
    errors: pd.DataFrame
        Error columns plus the column `freq_col`. The rolling mean is taken
        over the rows in their given order.
    metric_name: str
        Name of the error metric, used in the y-axis label.
    window: int
        Size of the rolling window (a minimum of one observation is
        required per window).
    min_freq: optional
        If given, only rows with `freq_col` strictly greater than `min_freq`
        are plotted. Defaults to the minimum of `freq_col` (no filtering).
        It is also the lower x-axis limit.
    freq_col: str
        Column in `errors` used for the x-axis.
    colors_to_use: optional
        Forwarded as `color` to `DataFrame.plot`.
    ax: optional
        Axes to draw on, forwarded to `DataFrame.plot`.

    Returns
    -------
    Axes
        The axes returned by `DataFrame.plot`.
    """
    errors_smoothed = (
        errors.drop(freq_col, axis=1).rolling(window=window, min_periods=1).mean()
    )
    # maximum over all smoothed error columns, taken before any filtering
    errors_smoothed_max = errors_smoothed.max().max()
    errors_smoothed[freq_col] = errors[freq_col]
    if min_freq is None:
        min_freq = errors_smoothed[freq_col].min()
    else:
        errors_smoothed = errors_smoothed.loc[errors_smoothed[freq_col] > min_freq]
    ax = errors_smoothed.plot(
        x=freq_col,
        ylabel=f"rolling average error ({metric_name})",
        color=colors_to_use,
        xlim=(min_freq, errors_smoothed[freq_col].max()),
        # cap the y-axis at 5 so that large errors do not dominate the plot
        ylim=(0, min(errors_smoothed_max, 5)),
        # draw on the caller-supplied axes instead of always creating new ones
        ax=ax,
    )
    return ax