Source code for scvi.inference.inference

import logging
import copy
from typing import Union, List

import matplotlib.pyplot as plt
import torch
from numpy import ceil

from scvi.dataset import GeneExpressionDataset
from scvi.inference import Trainer

plt.switch_backend("agg")
logger = logging.getLogger(__name__)


class UnsupervisedTrainer(Trainer):
    """Class for unsupervised training of an autoencoder.

    Parameters
    ----------
    model
        A model instance from class ``VAE``, ``VAEC``, ``SCANVI``, ``AutoZIVAE``
    gene_dataset
        A gene_dataset instance like ``CortexDataset()``
    train_size
        The train size, a float between 0 and 1 representing the proportion of the
        dataset to use for training. Default: ``0.9``.
    test_size
        The test size, a float between 0 and 1 representing the proportion of the
        dataset to use for testing. Default: ``None``, which is equivalent to the
        data not in the train set. If ``train_size`` and ``test_size`` do not add
        up to 1, the remaining samples are added to a ``validation_set``.
    **kwargs
        Other keyword arguments from the general Trainer class.

    Other Parameters
    ----------------
    n_epochs_kl_warmup
        Number of epochs for linear warmup of the KL(q(z|x)||p(z)) term. After
        `n_epochs_kl_warmup`, the training objective is the ELBO. This might be
        used to prevent inactivity of latent units, and/or to improve clustering
        of the latent space, as a long warmup makes the model behave more like a
        plain autoencoder. Be aware that large datasets should avoid this mode and
        rely on `n_iter_kl_warmup` instead. If this parameter is not None, it
        overrides any choice of `n_iter_kl_warmup`.
    n_iter_kl_warmup
        Number of iterations for warmup (useful for bigger datasets).
        int(128*5000/400) is a good default value.
    normalize_loss
        A boolean determining whether the loss is divided by the total number of
        samples used for training. In particular, when the global KL divergence is
        equal to 0 and the division is performed, the loss for a minibatch is equal
        to the average of the reconstruction losses and KL divergences on the
        minibatch. Default: ``None``, which is equivalent to False when the model
        is an instance of ``AutoZIVAE`` and True otherwise.

    Examples
    --------
    >>> gene_dataset = CortexDataset()
    >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
    ...           n_labels=gene_dataset.n_labels)
    >>> infer = VariationalInference(gene_dataset, vae, train_size=0.5)
    >>> infer.train(n_epochs=20, lr=1e-3)

    Notes
    -----
    Two parameters can help control the training KL annealing. If your
    applications rely on the posterior quality (i.e. differential expression,
    batch effect removal), ensure that the total number of epochs (or iterations)
    exceeds the number of epochs (or iterations) used for KL warmup.
    """

    default_metrics_to_monitor = ["elbo"]

    def __init__(
        self,
        model,
        gene_dataset: GeneExpressionDataset,
        train_size: Union[int, float] = 0.9,
        test_size: Union[int, float] = None,
        n_iter_kl_warmup: Union[int, None] = None,
        n_epochs_kl_warmup: Union[int, None] = 400,
        normalize_loss: bool = None,
        **kwargs
    ):
        train_size = float(train_size)
        if train_size > 1.0 or train_size <= 0.0:
            raise ValueError(
                "train_size needs to be greater than 0 and less than or equal to 1"
            )
        super().__init__(model, gene_dataset, **kwargs)

        # Set up number of warmup iterations
        self.n_iter_kl_warmup = n_iter_kl_warmup
        self.n_epochs_kl_warmup = n_epochs_kl_warmup
        self.normalize_loss = (
            not (
                hasattr(self.model, "reconstruction_loss")
                and self.model.reconstruction_loss == "autozinb"
            )
            if normalize_loss is None
            else normalize_loss
        )

        # Total size of the dataset used for training
        # (e.g. training set in this class but testing set in AdapterTrainer).
        # It is used to rescale minibatch losses (cf. eq. (8) in Kingma and Welling,
        # Auto-Encoding Variational Bayes, ICLR 2014).
        self.n_samples = 1.0

        if type(self) is UnsupervisedTrainer:
            (
                self.train_set,
                self.test_set,
                self.validation_set,
            ) = self.train_test_validation(model, gene_dataset, train_size, test_size)
            self.train_set.to_monitor = ["elbo"]
            self.test_set.to_monitor = ["elbo"]
            self.validation_set.to_monitor = ["elbo"]
            self.n_samples = len(self.train_set.indices)

    @property
    def posteriors_loop(self):
        return ["train_set"]
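
    # Split illustration (proportions are illustrative, not defaults): per the
    # docstring above, with train_size=0.6 and test_size=0.2, train_test_validation
    # assigns 60% of the cells to train_set, 20% to test_set, and the remaining 20%
    # to validation_set; with the defaults (train_size=0.9, test_size=None), the
    # test_set receives the remaining 10% and the validation_set is empty.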

    def loss(self, tensors: List, feed_labels: bool = True):
        # The next lines should not be modified, because scanVI's trainer inherits
        # from this class and should NOT include label information to compute the
        # ELBO by default
        sample_batch, local_l_mean, local_l_var, batch_index, y = tensors
        if not feed_labels:
            y = None
        reconst_loss, kl_divergence_local, kl_divergence_global = self.model(
            sample_batch, local_l_mean, local_l_var, batch_index, y
        )
        loss = (
            self.n_samples
            * torch.mean(reconst_loss + self.kl_weight * kl_divergence_local)
            + kl_divergence_global
        )
        if self.normalize_loss:
            loss = loss / self.n_samples
        return loss
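
    # Scaling illustration (numbers are assumptions, not defaults): with
    # self.n_samples = 1000 training cells, a minibatch of 128 cells and
    # kl_weight = 0.5, the returned loss is
    #     1000 * mean(reconst_loss + 0.5 * kl_divergence_local) + kl_divergence_global
    # i.e. the per-cell minibatch average is rescaled to the size of the full
    # training set; with normalize_loss=True, it is divided back down to a
    # per-cell value.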

    @property
    def kl_weight(self):
        epoch_criterion = self.n_epochs_kl_warmup is not None
        iter_criterion = self.n_iter_kl_warmup is not None
        if epoch_criterion:
            kl_weight = min(1.0, self.epoch / self.n_epochs_kl_warmup)
        elif iter_criterion:
            kl_weight = min(1.0, self.n_iter / self.n_iter_kl_warmup)
        else:
            kl_weight = 1.0
        return kl_weight
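
    # Warmup schedule illustration (assuming the default n_epochs_kl_warmup=400 and
    # epochs counted from 0): kl_weight grows linearly from 0.0 to 1.0 and then
    # stays at 1.0, e.g.
    #     epoch   0 -> kl_weight = 0.0
    #     epoch 100 -> kl_weight = 0.25
    #     epoch 400 -> kl_weight = 1.0 (full ELBO from here on)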

    def on_training_begin(self):
        epoch_criterion = self.n_epochs_kl_warmup is not None
        iter_criterion = self.n_iter_kl_warmup is not None
        if epoch_criterion:
            log_message = "KL warmup for {} epochs".format(self.n_epochs_kl_warmup)
            if self.n_epochs_kl_warmup > self.n_epochs:
                logger.info(
                    "KL warmup phase exceeds overall training phase. "
                    "If your applications rely on the posterior quality, "
                    "consider training for more epochs or reducing the KL warmup."
                )
        elif iter_criterion:
            log_message = "KL warmup for {} iterations".format(self.n_iter_kl_warmup)
            n_iter_per_epochs_approx = ceil(
                self.gene_dataset.nb_cells / self.batch_size
            )
            n_total_iter_approx = self.n_epochs * n_iter_per_epochs_approx
            if self.n_iter_kl_warmup > n_total_iter_approx:
                logger.info(
                    "KL warmup phase may exceed overall training phase. "
                    "If your applications rely on posterior quality, "
                    "consider training for more epochs or reducing the KL warmup."
                )
        else:
            log_message = "Training without KL warmup"
        logger.info(log_message)

    def on_training_end(self):
        if self.kl_weight < 0.99:
            logger.info(
                "Training is still in the warm-up phase. "
                "If your applications rely on the posterior quality, "
                "consider training for more epochs or reducing the KL warmup."
            )
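

# Usage sketch (not part of the original module; imports, dataset and parameter
# values are illustrative assumptions rather than required choices):
#
#     from scvi.dataset import CortexDataset
#     from scvi.models import VAE
#
#     gene_dataset = CortexDataset()
#     vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches)
#     trainer = UnsupervisedTrainer(
#         vae, gene_dataset, train_size=0.9, n_epochs_kl_warmup=400, frequency=1
#     )
#     trainer.train(n_epochs=100, lr=1e-3)
#     train_elbo = trainer.history["elbo_train_set"]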


class AdapterTrainer(UnsupervisedTrainer):
    def __init__(self, model, gene_dataset, posterior_test, frequency=5):
        super().__init__(model, gene_dataset, frequency=frequency)
        self.test_set = posterior_test
        self.test_set.to_monitor = ["elbo"]
        self.params = list(self.model.z_encoder.parameters()) + list(
            self.model.l_encoder.parameters()
        )
        self.z_encoder_state = copy.deepcopy(model.z_encoder.state_dict())
        self.l_encoder_state = copy.deepcopy(model.l_encoder.state_dict())
        self.n_scale = len(self.test_set.indices)

    @property
    def posteriors_loop(self):
        return ["test_set"]

    def train(self, n_path=10, n_epochs=50, **kwargs):
        for i in range(n_path):
            # Re-initialize to create new path
            self.model.z_encoder.load_state_dict(self.z_encoder_state)
            self.model.l_encoder.load_state_dict(self.l_encoder_state)
            super().train(n_epochs, params=self.params, **kwargs)

        return min(self.history["elbo_test_set"])
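

# Usage sketch (not part of the original module): AdapterTrainer re-trains only the
# z/l encoders against a held-out posterior while the rest of the model stays fixed;
# `trainer` below is assumed to be a previously fitted UnsupervisedTrainer.
#
#     adapter = AdapterTrainer(vae, gene_dataset, trainer.test_set, frequency=1)
#     best_test_elbo = adapter.train(n_path=10, n_epochs=50, lr=1e-3)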