# Source code for pimmslearn.plotting.defaults
import logging
import matplotlib as mpl
import seaborn as sns
logger = logging.getLogger(__name__)
# ! default seaborn color map only has 10 colors
# https://seaborn.pydata.org/tutorial/color_palettes.html
# sns.color_palette("husl", N) to get N distinct colors
# Fixed model -> color assignments.
# The first ten models take their color from the currently active seaborn
# palette; the palette is fetched once instead of once per entry.
_base_palette = sns.color_palette()
color_model_mapping = {
    "KNN": _base_palette[0],
    "KNN_IMPUTE": _base_palette[1],
    "CF": _base_palette[2],
    "DAE": _base_palette[3],
    "VAE": _base_palette[4],
    "RF": _base_palette[5],
    "Median": _base_palette[6],
    "None": _base_palette[7],
    "BPCA": _base_palette[8],
    "MICE-CART": _base_palette[9],
}

# Further models draw from a 20-color "husl" palette: the first ten entries
# are reserved for named models, the rest is handed out by `assign_colors`.
other_colors = sns.color_palette("husl", 20)
color_model_mapping.update(
    {
        "IMPSEQ": other_colors[0],
        "QRILC": other_colors[1],
        # IMPSEQROB used to be assigned twice (first index 1, shared with
        # QRILC, then index 4). Only the last assignment ever took effect, so
        # it is set once here; its position keeps the dict's key order as
        # before.
        "IMPSEQROB": other_colors[4],
        "MICE-NORM": other_colors[2],
        "SEQKNN": other_colors[3],
        "GSIMP": other_colors[5],
        "MSIMPUTE": other_colors[6],
        "MSIMPUTE_MNAR": other_colors[7],
        "TRKNN": other_colors[8],
        "SVDMETHOD": other_colors[9],
    }
)
# Colors still available for models without a fixed assignment.
other_colors = other_colors[10:]
[docs]
def assign_colors(models):
    """Return a list with one color per entry of ``models``.

    Names found in ``color_model_mapping`` get their fixed color. Every other
    name takes the next color from ``other_colors``, wrapping around to the
    start once all of them have been handed out. If wrapping happened, an
    info message is logged.
    """
    n_spare = len(other_colors)
    n_unmapped = 0
    colors = []
    for name in models:
        if name not in color_model_mapping:
            # cycle through the spare colors for models without a fixed color
            colors.append(other_colors[n_unmapped % n_spare])
            n_unmapped += 1
            continue
        colors.append(color_model_mapping[name])
    if n_unmapped > n_spare:
        logger.info("Reused some colors!")
    return colors
[docs]
class ModelColorVisualizer:
    """Render model names next to their colors as a small SVG legend.

    Parameters
    ----------
    models
        Sized collection of model names; ``len(models)`` is taken when the
        SVG is rendered.
    palette
        Iterable of colors, paired with ``models`` by position. Each entry is
        converted with ``matplotlib.colors.to_rgb`` on construction, so an
        invalid color raises here rather than on first display.
    """

    def __init__(self, models, palette):
        self.models = models
        # Store a list, not a lazy ``map``: an iterator would be exhausted by
        # the first ``as_hex()`` call and every later call (e.g. displaying
        # the object a second time) would see an empty palette.
        self.palette = [mpl.colors.to_rgb(color) for color in palette]

    def as_hex(self):
        """Return a color palette with hex codes instead of RGB values."""
        return [mpl.colors.rgb2hex(rgb) for rgb in self.palette]

    def _repr_html_(self):
        """Rich display of the color palette in an HTML frontend."""
        s = 55  # base size in px: swatches are 2*s wide and s/2 high
        n = len(self.models)
        parts = [f'<svg width="{s*2}" height="{s*n/2}">']
        # One colored rectangle per model with the model name drawn on top.
        for i, (m, c) in enumerate(zip(self.models, self.as_hex())):
            y = i * s / 2
            parts.append(
                f'<rect x="0" y="{y}" width="{s*2}" height="{s/2}" style="fill:{c};'
                'stroke-width:2;stroke:rgb(255,255,255)" metadata="tt"/>'
            )
            parts.append(
                f'<text x="{4}" y="{y + 20}" font-size="12" fill="black">{m}</text>'
            )
        parts.append("</svg>")
        return "".join(parts)
# Maps internal column / hyperparameter names to shorter labels.
labels_dict = dict(
    (
        ("NA not interpolated valid_collab collab MSE", "MSE"),
        ("batch_size", "bs"),
        ("n_hidden_layers", "No. of hidden layers"),
        ("latent_dim", "hidden layer dimension"),
        ("subset_w_N", "subset"),
        ("n_params", "no. of parameter"),
        ("metric_value", "value"),
        ("metric_name", "metric"),
    )
)