Source code for pimmslearn.models.vae

"""VAE implementation based on https://github.com/ronaldiscool/VAETutorial

Adapted to the setup of learning missing values.

- funnel architecture (or a fixed hidden-layer layout)
- loss adapted to the project's Dataset and FastAI setup
- BatchNorm1d for now (no weight normalization)
"""

from typing import List

import torch
import torch.nn.functional as F
from torch import nn

# Default non-linearity for hidden layers: LeakyReLU with slope 0.1 for negative inputs.
leaky_relu_default = nn.LeakyReLU(negative_slope=0.1)

# 0-dim tensor holding pi; used in the normalising constant of the Gaussian log-density.
PI = torch.tensor(torch.pi)
# 0-dim tensor holding ln(2); multiplies the free-bits budget in the KL term of the loss.
log_of_2 = torch.tensor(2.0).log()


class VAE(nn.Module):
    """Variational autoencoder with a Gaussian latent space and Gaussian decoder.

    The encoder maps ``n_features`` inputs through the hidden sizes in
    ``n_neurons`` to ``2 * dim_latent`` outputs. These are split into latent
    means and log-variances. The decoder mirrors the hidden sizes in reverse
    order and outputs ``2 * n_features`` values. These are split into
    reconstruction means and log-variances.

    Every hidden block is ``Linear -> Dropout(0.2) -> BatchNorm1d -> activation``.

    Args:
        n_features: Number of input (and reconstructed) features.
        n_neurons: Hidden-layer sizes of the encoder, ordered from the input
            side towards the latent side. The decoder uses them reversed.
            Must contain at least one entry.
        activation: Module placed after every hidden block. The *same*
            instance is reused in every block, so a module with parameters
            would share them across layers.
        last_decoder_activation: Optional module appended after the decoder's
            final ``Linear``. It acts on the whole decoder output, i.e. on both
            the mean and the log-variance halves.
        dim_latent: Size of the latent space.

    Raises:
        ValueError: If ``n_neurons`` is empty.
    """

    def __init__(
        self,
        n_features: int,
        n_neurons: List[int],
        activation=leaky_relu_default,
        # last_encoder_activation=leaky_relu_default,
        last_decoder_activation=None,
        dim_latent: int = 10,
    ):
        super().__init__()
        # set up hyperparameters
        self.n_features, self.n_neurons = n_features, list(n_neurons)
        if not self.n_neurons:
            # Without a hidden layer, neither encoder nor decoder can be wired up.
            raise ValueError("n_neurons must contain at least one hidden layer size.")
        self.layers = [n_features, *self.n_neurons]
        self.dim_latent = dim_latent

        def build_layer(in_feat, out_feat):
            """Return the modules of one hidden block (not yet wrapped)."""
            return [
                nn.Linear(in_feat, out_feat),
                nn.Dropout(0.2),
                nn.BatchNorm1d(out_feat),
                activation,
            ]

        # Encoder: n_features -> n_neurons[0] -> ... -> n_neurons[-1] -> 2 * dim_latent
        # (modules are created in the same order as before, keeping seeded init stable)
        encoder = []
        for in_feat, out_feat in zip(self.layers[:-1], self.layers[1:]):
            encoder.extend(build_layer(in_feat=in_feat, out_feat=out_feat))
        encoder.append(nn.Linear(self.layers[-1], dim_latent * 2))
        self.encoder = nn.Sequential(*encoder)

        # Decoder: dim_latent -> n_neurons[-1] -> ... -> n_neurons[0] -> 2 * n_features
        self.layers_decoder = self.layers[::-1]
        decoder = build_layer(in_feat=self.dim_latent, out_feat=self.layers_decoder[0])
        hidden_decoder = self.layers_decoder[:-1]  # hidden sizes only, decoder order
        for in_feat, out_feat in zip(hidden_decoder[:-1], hidden_decoder[1:]):
            decoder.extend(build_layer(in_feat=in_feat, out_feat=out_feat))
        decoder.append(
            nn.Linear(self.layers_decoder[-2], self.layers_decoder[-1] * 2)
        )
        if last_decoder_activation is not None:
            # Append to the layer list; nn.Module itself has no ``append``.
            decoder.append(last_decoder_activation)
        self.decoder = nn.Sequential(*decoder)

    def encode(self, x):
        """Encode ``x`` and return ``(z_mu, z_logvar)``.

        The first ``dim_latent`` output columns are the means; the remaining
        columns are the log-variances.
        """
        z_params = self.encoder(x)
        z_mu = z_params[:, : self.dim_latent]
        z_logvar = z_params[:, self.dim_latent :]
        return z_mu, z_logvar

    def get_mu_and_logvar(self, x, detach=False):
        """Alias of :meth:`encode`. ``detach`` is accepted but currently ignored."""
        return self.encode(x)

    def decode(self, z):
        """Decode ``z`` and return ``(x_mu, x_logvar)``.

        The first ``n_features`` output columns are the means; the remaining
        columns are the log-variances.
        """
        x_params = self.decoder(z)
        x_mu = x_params[:, : self.n_features]
        x_logvar = x_params[:, self.n_features :]
        return x_mu, x_logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``.

        This is the reparameterisation trick.
        """
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def forward(self, x):
        """Run encode -> sample -> decode.

        Returns:
            ``(x_mu, x_logvar, z_mu, z_logvar)``.
        """
        z_mu, z_logvar = self.encode(x)
        z = self.reparameterize(z_mu, z_logvar)
        x_mu, x_logvar = self.decode(z)
        return x_mu, x_logvar, z_mu, z_logvar
def compute_kld(z_mu, z_logvar):
    """Element-wise KL divergence KL(N(z_mu, exp(z_logvar)) || N(0, 1)).

    Uses the closed form ``0.5 * (mu^2 + sigma^2 - 1 - log sigma^2)``.
    No reduction is applied; the result has the broadcast shape of the inputs.
    """
    squared_mean = z_mu.pow(2)
    variance = z_logvar.exp()
    return 0.5 * (squared_mean + variance - 1 - z_logvar)
def gaussian_log_prob(z, mu, logvar):
    """Element-wise log-density of ``z`` under N(mu, exp(logvar)).

    Computes ``-0.5 * (log(2 * pi) + logvar + (z - mu)^2 / exp(logvar))``.
    No reduction is applied.
    """
    log_normaliser = torch.log(2.0 * PI)
    scaled_sq_error = (z - mu) ** 2 / torch.exp(logvar)
    return -0.5 * (log_normaliser + logvar + scaled_sq_error)
def loss_fct(pred, y, reduction="sum", results: List = None, freebits=0.1):
    """Negative ELBO with a free-bits floor on each latent dimension's KL term.

    Args:
        pred: Tuple ``(x_mu, x_logvar, z_mu, z_logvar)``, the same layout that
            ``VAE.forward`` returns.
        y: Target values scored under the reconstruction Gaussian.
        reduction: Accepted for interface compatibility; not used.
        results: Optional list. If given, a tuple
            ``(summed reconstruction loss, mean KL term)`` of Python floats is
            appended to it.
        freebits: Per-dimension KL floor, multiplied by ``log_of_2`` before use.

    Returns:
        Summed reconstruction loss divided by the batch size (``z_mu.shape[0]``),
        plus the batch mean of the per-sample KL term.
    """
    x_mu, x_logvar, z_mu, z_logvar = pred
    floor = freebits * log_of_2
    rec_loss = -gaussian_log_prob(y, x_mu, x_logvar).sum()
    # relu(kld - floor) + floor is (mathematically) max(kld, floor):
    # every latent dimension contributes at least ``floor``.
    kl_per_dim = F.relu(compute_kld(z_mu, z_logvar) - floor) + floor
    kl_per_sample = kl_per_dim.sum(1)
    mean_kl = kl_per_sample.mean()
    if results is not None:
        results.append((rec_loss.item(), mean_kl.item()))
    return rec_loss / kl_per_sample.shape[0] + mean_kl