Source code for scvi.inference.trainer

import logging
import sys
import time
from abc import abstractmethod
from collections import defaultdict, OrderedDict
from itertools import cycle
from typing import List

import numpy as np
import torch
from sklearn.model_selection._split import _validate_shuffle_split
from torch.utils.data.sampler import SubsetRandomSampler

from scvi.dataset import GeneExpressionDataset
from scvi.inference.posterior import Posterior

from tqdm import tqdm

logger = logging.getLogger(__name__)


class Trainer:
    """The abstract Trainer class for training a PyTorch model and monitoring its statistics.

    It should be inherited at least with a ``.loss()`` function to be optimized in the training loop.

    Parameters
    ----------
    model :
        A model instance from class ``VAE``, ``VAEC``, ``SCANVI``
    gene_dataset :
        A gene_dataset instance like ``CortexDataset()``
    use_cuda :
        Default: ``True``.
    metrics_to_monitor :
        A list of the metrics to monitor. If not specified, will use the
        ``default_metrics_to_monitor`` as specified in each trainer subclass. Default: ``None``.
    benchmark :
        If True, prevents statistics computation in the training. Default: ``False``.
    frequency :
        The frequency at which to keep track of statistics. Default: ``None``.
    early_stopping_metric :
        The statistics on which to perform early stopping. Default: ``None``.
    save_best_state_metric :
        The statistics on which we keep the network weights achieving the best score,
        and restore them at the end of training. Default: ``None``.
    on :
        The data_loader name reference for the ``early_stopping_metric`` and
        ``save_best_state_metric``, which should be specified if either of them is.
        Default: ``None``.
    show_progbar :
        If False, disables the progress bar.
    seed :
        Random seed for the train/test/validation split.
    """

    default_metrics_to_monitor = []

    def __init__(
        self,
        model,
        gene_dataset: GeneExpressionDataset,
        use_cuda: bool = True,
        metrics_to_monitor: List = None,
        benchmark: bool = False,
        frequency: int = None,
        weight_decay: float = 1e-6,
        early_stopping_kwargs: dict = None,
        data_loader_kwargs: dict = None,
        show_progbar: bool = True,
        batch_size: int = 128,
        seed: int = 0,
        max_nans: int = 10,
    ):
        # Model, dataset management
        self.model = model
        self.gene_dataset = gene_dataset
        self._posteriors = OrderedDict()
        self.seed = seed  # For train/test splitting

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.model.cuda()

        # Data loader attributes
        self.batch_size = batch_size
        self.data_loader_kwargs = {"batch_size": batch_size, "pin_memory": use_cuda}
        data_loader_kwargs = data_loader_kwargs if data_loader_kwargs else dict()
        self.data_loader_kwargs.update(data_loader_kwargs)

        # Optimization attributes
        self.optimizer = None
        self.weight_decay = weight_decay
        self.n_epochs = None
        self.epoch = -1  # epoch = self.epoch + 1 in compute metrics
        self.training_time = 0
        self.n_iter = 0

        # Training NaNs handling
        self.max_nans = max_nans
        self.current_loss = None  # torch.Tensor training loss
        self.previous_loss_was_nan = False
        self.nan_counter = 0  # Counts occurring NaNs during training

        # Metrics and early stopping
        self.compute_metrics_time = None
        if metrics_to_monitor is not None:
            self.metrics_to_monitor = set(metrics_to_monitor)
        else:
            self.metrics_to_monitor = set(self.default_metrics_to_monitor)
        early_stopping_kwargs = (
            early_stopping_kwargs if early_stopping_kwargs else dict()
        )
        self.early_stopping = EarlyStopping(**early_stopping_kwargs)
        self.benchmark = benchmark
        self.frequency = frequency if not benchmark else None
        self.history = defaultdict(list)

        self.best_state_dict = self.model.state_dict()
        self.best_epoch = self.epoch

        if self.early_stopping.early_stopping_metric:
            self.metrics_to_monitor.add(self.early_stopping.early_stopping_metric)

        self.show_progbar = show_progbar

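    # A minimal subclassing sketch (illustrative, not part of the library):
    # a concrete trainer must provide a ``loss`` method whose signature matches
    # what ``data_loaders_loop`` yields, plus a ``posteriors_loop`` property.
    # The names ``MyTrainer`` and ``train_set`` and the assumed model return
    # values are hypothetical.
    #
    #     class MyTrainer(Trainer):
    #         default_metrics_to_monitor = ["elbo"]
    #
    #         @property
    #         def posteriors_loop(self):
    #             return ["train_set"]
    #
    #         def loss(self, tensors):
    #             sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors
    #             reconst_loss, kl_divergence, kl_global = self.model(
    #                 sample_batch, local_l_mean, local_l_var, batch_index
    #             )
    #             return torch.mean(reconst_loss + kl_divergence) + kl_global
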
    @torch.no_grad()
    def compute_metrics(self):
        begin = time.time()
        epoch = self.epoch + 1
        if self.frequency and (
            epoch == 0 or epoch == self.n_epochs or (epoch % self.frequency == 0)
        ):
            with torch.set_grad_enabled(False):
                self.model.eval()
                logger.debug("\nEPOCH [%d/%d]: " % (epoch, self.n_epochs))

                for name, posterior in self._posteriors.items():
                    message = " ".join([s.capitalize() for s in name.split("_")[-2:]])
                    if posterior.nb_cells < 5:
                        logger.debug(
                            message + " is too small to track metrics (<5 samples)"
                        )
                        continue
                    if hasattr(posterior, "to_monitor"):
                        for metric in posterior.to_monitor:
                            if metric not in self.metrics_to_monitor:
                                logger.debug(message)
                                result = getattr(posterior, metric)()
                                self.history[metric + "_" + name] += [result]
                    for metric in self.metrics_to_monitor:
                        result = getattr(posterior, metric)()
                        self.history[metric + "_" + name] += [result]
                self.model.train()
        self.compute_metrics_time += time.time() - begin

    def train(self, n_epochs=400, lr=1e-3, eps=0.01, params=None, **extras_kwargs):
        begin = time.time()
        self.model.train()

        if params is None:
            params = filter(lambda p: p.requires_grad, self.model.parameters())

        self.optimizer = torch.optim.Adam(
            params, lr=lr, eps=eps, weight_decay=self.weight_decay
        )

        # Initialization of other models' optimizers
        self.training_extras_init(**extras_kwargs)

        self.compute_metrics_time = 0
        self.n_epochs = n_epochs
        self.compute_metrics()

        self.on_training_begin()

        for self.epoch in tqdm(
            range(n_epochs),
            desc="training",
            disable=not self.show_progbar,
            file=sys.stdout,
        ):
            self.on_epoch_begin()
            for tensors_list in self.data_loaders_loop():
                if tensors_list[0][0].shape[0] < 3:
                    continue
                self.on_iteration_begin()
                # Update the model's parameters after seeing the data
                self.on_training_loop(tensors_list)
                # Checks the training status, ensures no NaN loss
                self.on_iteration_end()

            # Computes metrics and controls early stopping
            if not self.on_epoch_end():
                break

        if self.early_stopping.save_best_state_metric is not None:
            self.model.load_state_dict(self.best_state_dict)
            self.compute_metrics()

        self.model.eval()
        self.training_extras_end()

        self.training_time += (time.time() - begin) - self.compute_metrics_time
        if self.frequency:
            logger.debug(
                "\nTraining time:  %i s. / %i epochs"
                % (int(self.training_time), self.n_epochs)
            )
        self.on_training_end()

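    # End-to-end usage sketch, assuming the ``UnsupervisedTrainer`` subclass,
    # the ``VAE`` model and the ``CortexDataset`` shipped with scvi; the
    # hyperparameter values and the ``"elbo_test_set"`` history key are
    # illustrative:
    #
    #     from scvi.dataset import CortexDataset
    #     from scvi.inference import UnsupervisedTrainer
    #     from scvi.models import VAE
    #
    #     gene_dataset = CortexDataset()
    #     vae = VAE(gene_dataset.nb_genes)
    #     trainer = UnsupervisedTrainer(vae, gene_dataset, train_size=0.9, frequency=1)
    #     trainer.train(n_epochs=400, lr=1e-3)
    #     elbo_history = trainer.history["elbo_test_set"]
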
    def on_training_loop(self, tensors_list):
        self.current_loss = loss = self.loss(*tensors_list)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def training_extras_init(self, **extras_kwargs):
        """Initializes other models that need to be trained simultaneously."""
        pass

    def training_extras_end(self):
        """Place to put extra models in eval mode, etc."""
        pass

    def on_training_begin(self):
        pass

    def on_epoch_begin(self):
        # Epochs refer to a pass through the entire dataset (in minibatches)
        pass

    def on_epoch_end(self):
        self.compute_metrics()
        on = self.early_stopping.on
        early_stopping_metric = self.early_stopping.early_stopping_metric
        save_best_state_metric = self.early_stopping.save_best_state_metric
        if save_best_state_metric is not None and on is not None:
            if self.early_stopping.update_state(
                self.history[save_best_state_metric + "_" + on][-1]
            ):
                self.best_state_dict = self.model.state_dict()
                self.best_epoch = self.epoch

        continue_training = True
        if early_stopping_metric is not None and on is not None:
            continue_training, reduce_lr = self.early_stopping.update(
                self.history[early_stopping_metric + "_" + on][-1]
            )
            if reduce_lr:
                logger.info("Reducing LR on epoch {}.".format(self.epoch))
                for param_group in self.optimizer.param_groups:
                    param_group["lr"] *= self.early_stopping.lr_factor

        return continue_training

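    # Configuration sketch for the early-stopping hooks used above: the scalar
    # fed to ``EarlyStopping`` is read from ``self.history`` under the key
    # ``"<metric>_<on>"``. The values below are illustrative; the dict is
    # forwarded to ``EarlyStopping`` through the ``early_stopping_kwargs``
    # argument of ``__init__``:
    #
    #     early_stopping_kwargs = {
    #         "early_stopping_metric": "elbo",
    #         "save_best_state_metric": "elbo",
    #         "on": "test_set",
    #         "patience": 15,
    #         "threshold": 3,
    #         "reduce_lr_on_plateau": True,
    #         "lr_patience": 10,
    #         "lr_factor": 0.5,
    #     }
    #     trainer = UnsupervisedTrainer(
    #         vae, gene_dataset, early_stopping_kwargs=early_stopping_kwargs
    #     )
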
    def on_iteration_begin(self):
        # Iterations refer to minibatches
        pass

    def on_iteration_end(self):
        self.check_training_status()
        self.n_iter += 1

    def on_training_end(self):
        pass

    def check_training_status(self):
        """Checks if the loss is admissible.

        If not, training is stopped after ``max_nans`` consecutive inadmissible losses.
        The loss corresponds to the training loss of the model; ``max_nans`` is the
        maximum number of consecutive NaNs after which a ValueError is raised.
        """
        loss_is_nan = torch.isnan(self.current_loss).item()
        if loss_is_nan:
            logger.warning("Model training loss was NaN")
            self.nan_counter += 1
            self.previous_loss_was_nan = True
        else:
            self.nan_counter = 0
            self.previous_loss_was_nan = False

        if self.nan_counter >= self.max_nans:
            raise ValueError(
                "Loss was NaN {} consecutive times: the model is not training properly. "
                "Consider using a lower learning rate.".format(self.max_nans)
            )

    @property
    @abstractmethod
    def posteriors_loop(self):
        pass

    def data_loaders_loop(self):
        """Returns a zipped iterable corresponding to the signature of ``loss``."""
        data_loaders_loop = [self._posteriors[name] for name in self.posteriors_loop]
        return zip(
            data_loaders_loop[0],
            *[cycle(data_loader) for data_loader in data_loaders_loop[1:]]
        )

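    # Illustration of the zip/cycle behavior: if a subclass's
    # ``posteriors_loop`` returned, say, ``["full_dataset", "labelled_set"]``
    # (hypothetical names), each iteration would yield one minibatch per
    # loader, with every loader after the first cycled so that the epoch
    # length is dictated by the first loader:
    #
    #     for tensors_full, tensors_labelled in self.data_loaders_loop():
    #         loss = self.loss(tensors_full, tensors_labelled)
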
    def register_posterior(self, name, value):
        name = name.strip("_")
        self._posteriors[name] = value

    def corrupt_posteriors(
        self, rate=0.1, corruption="uniform", update_corruption=True
    ):
        if not hasattr(self.gene_dataset, "corrupted") and update_corruption:
            self.gene_dataset.corrupt(rate=rate, corruption=corruption)
        for name, posterior in self._posteriors.items():
            self.register_posterior(name, posterior.corrupted())

    def uncorrupt_posteriors(self):
        for name_, posterior in self._posteriors.items():
            self.register_posterior(name_, posterior.uncorrupted())

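    # Imputation-benchmark sketch using the corruption pair above
    # (the rate, corruption, and epoch values are illustrative):
    #
    #     trainer.corrupt_posteriors(rate=0.1, corruption="uniform")
    #     trainer.train(n_epochs=400)
    #     trainer.uncorrupt_posteriors()
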
    def __getattr__(self, name):
        if "_posteriors" in self.__dict__:
            _posteriors = self.__dict__["_posteriors"]
            if name.strip("_") in _posteriors:
                return _posteriors[name.strip("_")]
        return object.__getattribute__(self, name)

    def __delattr__(self, name):
        if name.strip("_") in self._posteriors:
            del self._posteriors[name.strip("_")]
        else:
            object.__delattr__(self, name)

    def __setattr__(self, name, value):
        if isinstance(value, Posterior):
            name = name.strip("_")
            self.register_posterior(name, value)
        else:
            object.__setattr__(self, name, value)

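    # Attribute-magic illustration: assigning a ``Posterior`` registers it, and
    # lookups work with or without trailing underscores. ``train_set`` and
    # ``indices_train`` are assumed names for the example:
    #
    #     trainer.train_set = trainer.create_posterior(indices=indices_train)
    #     assert trainer.train_set is trainer._posteriors["train_set"]
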
    def train_test_validation(
        self,
        model=None,
        gene_dataset=None,
        train_size=0.9,
        test_size=None,
        type_class=Posterior,
    ):
        """Creates posteriors ``train_set``, ``test_set``, ``validation_set``.

        If ``train_size + test_size < 1`` then ``validation_set`` is non-empty.

        Parameters
        ----------
        model :
            (Default value = None)
        gene_dataset :
            (Default value = None)
        train_size : float, or None (default is 0.9)
        test_size : float, or None (default is None)
        type_class :
            (Default value = Posterior)
        """
        train_size = float(train_size)
        if train_size > 1.0 or train_size <= 0.0:
            raise ValueError(
                "train_size needs to be greater than 0 and less than or equal to 1"
            )

        model = self.model if model is None and hasattr(self, "model") else model
        gene_dataset = (
            self.gene_dataset
            if gene_dataset is None and hasattr(self, "gene_dataset")
            else gene_dataset
        )

        n = len(gene_dataset)
        try:
            n_train, n_test = _validate_shuffle_split(n, test_size, train_size)
        except ValueError:
            if train_size != 1.0:
                raise ValueError(
                    "Choice of train_size={} and test_size={} not understood".format(
                        train_size, test_size
                    )
                )
            n_train, n_test = n, 0
        random_state = np.random.RandomState(seed=self.seed)
        permutation = random_state.permutation(n)
        indices_test = permutation[:n_test]
        indices_train = permutation[n_test : (n_test + n_train)]
        indices_validation = permutation[(n_test + n_train) :]

        return (
            self.create_posterior(
                model, gene_dataset, indices=indices_train, type_class=type_class
            ),
            self.create_posterior(
                model, gene_dataset, indices=indices_test, type_class=type_class
            ),
            self.create_posterior(
                model, gene_dataset, indices=indices_validation, type_class=type_class
            ),
        )

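    # Split sketch: a 60/20/20 train/test/validation split (sizes are
    # illustrative); because ``train_size + test_size < 1``, the leftover
    # indices form the validation set:
    #
    #     train_set, test_set, validation_set = trainer.train_test_validation(
    #         train_size=0.6, test_size=0.2
    #     )
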
    def create_posterior(
        self,
        model=None,
        gene_dataset=None,
        shuffle=False,
        indices=None,
        type_class=Posterior,
    ):
        model = self.model if model is None and hasattr(self, "model") else model
        gene_dataset = (
            self.gene_dataset
            if gene_dataset is None and hasattr(self, "gene_dataset")
            else gene_dataset
        )
        return type_class(
            model,
            gene_dataset,
            shuffle=shuffle,
            indices=indices,
            use_cuda=self.use_cuda,
            data_loader_kwargs=self.data_loader_kwargs,
        )

class SequentialSubsetSampler(SubsetRandomSampler):
    def __init__(self, indices):
        self.indices = np.sort(indices)

    def __iter__(self):
        return iter(self.indices)


class EarlyStopping:
    def __init__(
        self,
        early_stopping_metric: str = None,
        save_best_state_metric: str = None,
        on: str = "test_set",
        patience: int = 15,
        threshold: int = 3,
        benchmark: bool = False,
        reduce_lr_on_plateau: bool = False,
        lr_patience: int = 10,
        lr_factor: float = 0.5,
        posterior_class=Posterior,
    ):
        self.benchmark = benchmark
        self.patience = patience
        self.threshold = threshold
        self.epoch = 0
        self.wait = 0
        self.wait_lr = 0
        self.mode = (
            getattr(posterior_class, early_stopping_metric).mode
            if early_stopping_metric is not None
            else None
        )
        # We set the best to +inf because we're dealing with a loss we want to minimize
        self.current_performance = np.inf
        self.best_performance = np.inf
        self.best_performance_state = np.inf
        # If we want to maximize, we start at -inf
        if self.mode == "max":
            self.best_performance *= -1
            self.current_performance *= -1
        self.mode_save_state = (
            getattr(Posterior, save_best_state_metric).mode
            if save_best_state_metric is not None
            else None
        )
        if self.mode_save_state == "max":
            self.best_performance_state *= -1

        self.early_stopping_metric = early_stopping_metric
        self.save_best_state_metric = save_best_state_metric
        self.on = on
        self.reduce_lr_on_plateau = reduce_lr_on_plateau
        self.lr_patience = lr_patience
        self.lr_factor = lr_factor

    def update(self, scalar):
        self.epoch += 1
        if self.benchmark:
            continue_training = True
            reduce_lr = False
        elif self.wait >= self.patience:
            continue_training = False
            reduce_lr = False
        else:
            # Check if we should reduce the learning rate
            if not self.reduce_lr_on_plateau:
                reduce_lr = False
            elif self.wait_lr >= self.lr_patience:
                reduce_lr = True
                self.wait_lr = 0
            else:
                reduce_lr = False

            # Shift
            self.current_performance = scalar

            # Compute improvement
            if self.mode == "max":
                improvement = self.current_performance - self.best_performance
            elif self.mode == "min":
                improvement = self.best_performance - self.current_performance
            else:
                raise NotImplementedError("Unknown optimization mode")

            # Updating best performance
            if improvement > 0:
                self.best_performance = self.current_performance

            if improvement < self.threshold:
                self.wait += 1
                self.wait_lr += 1
            else:
                self.wait = 0
                self.wait_lr = 0

            continue_training = True

        if not continue_training:
            # FIXME: log total number of epochs run
            logger.info(
                "\nStopping early: no improvement of more than "
                + str(self.threshold)
                + " nats in "
                + str(self.patience)
                + " epochs"
            )
            logger.info(
                "If the early stopping criterion is too strong, "
                "please instantiate it with different parameters in the train method."
            )
        return continue_training, reduce_lr

    def update_state(self, scalar):
        improved = (
            self.mode_save_state == "max" and scalar - self.best_performance_state > 0
        ) or (
            self.mode_save_state == "min" and self.best_performance_state - scalar > 0
        )
        if improved:
            self.best_performance_state = scalar
        return improved

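# Standalone illustration of the EarlyStopping contract (the metric name and
# the ``history`` values are assumptions): ``update`` takes one scalar per
# epoch and returns ``(continue_training, reduce_lr)``, while ``update_state``
# returns True whenever the incoming scalar beats the best one seen so far.
#
#     stopper = EarlyStopping(
#         early_stopping_metric="elbo",
#         save_best_state_metric="elbo",
#         patience=15,
#         threshold=3,
#     )
#     for epoch_elbo in history["elbo_test_set"]:
#         continue_training, reduce_lr = stopper.update(epoch_elbo)
#         if stopper.update_state(epoch_elbo):
#             best_elbo = epoch_elbo
#         if not continue_training:
#             break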