Trainer

class scvi.inference.Trainer(model, gene_dataset, use_cuda=True, metrics_to_monitor=None, benchmark=False, frequency=None, weight_decay=1e-06, early_stopping_kwargs=None, data_loader_kwargs=None, show_progbar=True, batch_size=128, seed=0, max_nans=10)

Bases: object

The abstract Trainer class for training a PyTorch model and monitoring its statistics.

Subclasses must implement at least a .loss() method, which is the objective optimized in the training loop.
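The contract that subclasses supply .loss() can be made explicit with Python's abc module. This is a minimal sketch only: TrainerSketch and SquaredErrorTrainer are hypothetical names, and because the real class lists "Bases: object", scvi does not necessarily enforce the contract the way abc does here.

```python
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Hypothetical stand-in for Trainer: only the loss() contract is modeled."""

    @abstractmethod
    def loss(self, *tensors):
        """Return the scalar quantity to minimize for one minibatch."""


class SquaredErrorTrainer(TrainerSketch):
    """Hypothetical subclass: mean squared error between paired values."""

    def loss(self, predictions, targets):
        diffs = [(p - t) ** 2 for p, t in zip(predictions, targets)]
        return sum(diffs) / len(diffs)


trainer = SquaredErrorTrainer()
print(trainer.loss([1.0, 2.0], [1.0, 4.0]))  # prints 2.0
# TrainerSketch() itself raises TypeError, since loss() is abstract.
```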

Parameters
  • model – A model instance of class VAE, VAEC, or SCANVI.

  • gene_dataset (GeneExpressionDataset) – A gene_dataset instance, such as CortexDataset().

  • use_cuda (bool) – Whether to use CUDA for training. Default: True.

  • metrics_to_monitor (Optional[List]) – A list of the metrics to monitor. If not specified, the default_metrics_to_monitor defined by each Trainer subclass is used. Default: None.

  • benchmark (bool) – If True, disables the computation of statistics during training. Default: False.

  • frequency (Optional[int]) – How often, in epochs, statistics are computed and recorded. Default: None.

  • early_stopping_metric – The statistic on which to perform early stopping (passed in early_stopping_kwargs). Default: None.

  • save_best_state_metric – The statistic used to select the network weights: the weights achieving the best score are kept and restored at the end of training (passed in early_stopping_kwargs). Default: None.

  • on – The name of the data_loader on which early_stopping_metric and save_best_state_metric are computed; it must be specified if either of them is (passed in early_stopping_kwargs). Default: None.

  • show_progbar (bool) – If False, disables the progress bar. Default: True.

  • seed (int) – Random seed for the train/test/validation split. Default: 0.
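The weight-restoring behaviour described for save_best_state_metric can be pictured with a small, framework-free sketch. BestStateTracker is hypothetical and not part of scvi; it assumes that a lower score is better (as for a loss) and uses a plain dict in place of a PyTorch state_dict.

```python
import copy


class BestStateTracker:
    """Hypothetical helper: remembers the weights that achieved the best (lowest) score."""

    def __init__(self):
        self.best_score = float("inf")
        self.best_state = None

    def update(self, score, state):
        # Deep-copy the snapshot so later in-place weight updates cannot alter it.
        if score < self.best_score:
            self.best_score = score
            self.best_state = copy.deepcopy(state)


weights = {"w": 0.0}
tracker = BestStateTracker()
for epoch, score in enumerate([3.0, 1.0, 2.0]):
    weights["w"] = float(epoch)  # stand-in for an optimizer step
    tracker.update(score, weights)

# At the end of training, restore the best snapshot.
weights = tracker.best_state
print(weights)  # prints {'w': 1.0}
```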


Attributes Summary

default_metrics_to_monitor

posteriors_loop

Methods Summary

check_training_status()

Checks if loss is admissible.

compute_metrics()

corrupt_posteriors([rate, corruption, …])

create_posterior([model, gene_dataset, …])

data_loaders_loop()

Returns a zipped iterable corresponding to the loss signature.

on_epoch_begin()

on_epoch_end()

on_iteration_begin()

on_iteration_end()

on_training_begin()

on_training_end()

on_training_loop(tensors_list)

register_posterior(name, value)

train([n_epochs, lr, eps, params])

train_test_validation([model, gene_dataset, …])

Creates posteriors train_set, test_set, validation_set.

training_extras_end()

Hook for end-of-training work on extra models, such as putting them in eval mode.

training_extras_init(**extras_kwargs)

Hook for initializing other models that must be trained simultaneously.

uncorrupt_posteriors()

Attributes Documentation

default_metrics_to_monitor = []
posteriors_loop

Methods Documentation

check_training_status()

Checks if loss is admissible.

If not, training is stopped after max_nans consecutive inadmissible losses, where the loss is the training loss of the model.

max_nans is the maximum number of consecutive NaNs after which a ValueError is raised.
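The consecutive-NaN rule can be sketched in plain Python. NanGuard is a hypothetical name, the loss is a Python float rather than a torch tensor, and raising once the streak reaches max_nans (rather than exceeds it) is an assumption of this sketch.

```python
import math


class NanGuard:
    """Hypothetical sketch of a consecutive-NaN check (not scvi's code)."""

    def __init__(self, max_nans=10):
        self.max_nans = max_nans
        self.nan_counter = 0

    def check(self, loss):
        if math.isnan(loss):
            self.nan_counter += 1
        else:
            # A finite loss resets the streak: only consecutive NaNs count.
            self.nan_counter = 0
        if self.nan_counter >= self.max_nans:
            raise ValueError(f"Loss was NaN {self.max_nans} consecutive times")


guard = NanGuard(max_nans=2)
guard.check(float("nan"))  # streak of 1: tolerated
guard.check(0.5)           # finite loss resets the streak
guard.check(float("nan"))  # streak of 1 again
# A second consecutive NaN would now raise ValueError.
```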

compute_metrics()
corrupt_posteriors(rate=0.1, corruption='uniform', update_corruption=True)
create_posterior(model=None, gene_dataset=None, shuffle=False, indices=None, type_class=<class 'scvi.inference.posterior.Posterior'>)
data_loaders_loop()

Returns a zipped iterable corresponding to the loss signature.
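One plausible reading of "a zipped iterable corresponding to the loss signature" is that each step yields one batch per posterior named in posteriors_loop, in the order of loss()'s arguments. The sketch below assumes that reading; the function and the toy loaders are hypothetical. The first loader sets the number of steps and the others are repeated with itertools.cycle.

```python
from itertools import cycle


def data_loaders_loop(loaders):
    """Hypothetical sketch: zip a primary loader with the remaining loaders cycled.

    The first loader determines the epoch length; shorter secondary loaders are
    repeated so that every step yields one batch from each loader.
    """
    primary, *secondary = loaders
    return zip(primary, *[cycle(loader) for loader in secondary])


labelled = [["l0"], ["l1"]]             # 2 batches
unlabelled = [["u0"], ["u1"], ["u2"]]   # 3 batches
steps = list(data_loaders_loop([unlabelled, labelled]))
print(len(steps))   # prints 3
print(steps[2])     # prints (['u2'], ['l0'])
```

Each tuple can then be unpacked into a call such as loss(*tensors_list), which is why the order of the loaders has to match the order of loss()'s parameters.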

on_epoch_begin()
on_epoch_end()
on_iteration_begin()
on_iteration_end()
on_training_begin()
on_training_end()
on_training_loop(tensors_list)
register_posterior(name, value)
train(n_epochs=400, lr=0.001, eps=0.01, params=None, **extras_kwargs)
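The on_* methods suggest that train() is organized as a hook-based loop. The framework-free skeleton below is a hypothetical illustration of one plausible call order; the ordering, the rule of computing statistics every frequency epochs, and the convention that on_epoch_end() returning False stops training are assumptions of this sketch, not a description of scvi's train().

```python
class HookedTrainer:
    """Hypothetical skeleton of a hook-based training loop (not scvi's train())."""

    def __init__(self, batches, frequency=None):
        self.batches = batches
        self.frequency = frequency
        self.epoch = 0
        self.current_loss = None
        self.calls = []  # records each hook invocation, in order

    def data_loaders_loop(self):
        return iter(self.batches)

    def loss(self, batch):
        # Placeholder objective; a real subclass would return a PyTorch loss tensor.
        return sum(batch)

    def compute_metrics(self):
        self.calls.append(f"compute_metrics@{self.epoch}")

    def on_training_begin(self):
        self.calls.append("on_training_begin")

    def on_epoch_begin(self):
        self.calls.append("on_epoch_begin")

    def on_iteration_begin(self):
        self.calls.append("on_iteration_begin")

    def on_training_loop(self, batch):
        # A real trainer would also call backward() and step the optimizer here.
        self.current_loss = self.loss(batch)
        self.calls.append("on_training_loop")

    def on_iteration_end(self):
        # A natural place for a NaN check such as check_training_status().
        self.calls.append("on_iteration_end")

    def on_epoch_end(self):
        # Assumed rule: compute statistics every `frequency` epochs when frequency is set.
        if self.frequency and self.epoch % self.frequency == 0:
            self.compute_metrics()
        self.calls.append("on_epoch_end")
        return True  # returning False would end training early

    def on_training_end(self):
        self.calls.append("on_training_end")

    def train(self, n_epochs=400):
        self.on_training_begin()
        for epoch in range(n_epochs):
            self.epoch = epoch
            self.on_epoch_begin()
            for batch in self.data_loaders_loop():
                self.on_iteration_begin()
                self.on_training_loop(batch)
                self.on_iteration_end()
            if not self.on_epoch_end():
                break
        self.on_training_end()


trainer = HookedTrainer(batches=[[1, 2]], frequency=2)
trainer.train(n_epochs=2)
print(trainer.calls)
```

Keeping each step in its own method lets a subclass override a single hook (for example, on_training_loop for a custom optimization step) without copying the whole loop.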
train_test_validation(model=None, gene_dataset=None, train_size=0.9, test_size=None, type_class=<class 'scvi.inference.posterior.Posterior'>)

Creates posteriors train_set, test_set, validation_set.

If train_size + test_size < 1 then validation_set is non-empty.

Parameters
  • train_size (float or None) – Proportion of the dataset assigned to train_set. Default: 0.9.

  • test_size (float or None) – Proportion of the dataset assigned to test_set. Default: None.

  • model – Default: None.

  • gene_dataset – Default: None.

  • type_class – Default: Posterior.
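The relation between train_size, test_size and the validation set can be illustrated with a seeded index split. split_indices is a hypothetical helper; it treats both sizes as fractions, rounds down, and gives the test set everything not used for training when test_size is None. scvi's own rounding and assignment order may differ.

```python
import math
import random


def split_indices(n_cells, train_size=0.9, test_size=None, seed=0):
    """Hypothetical sketch of a seeded train/test/validation split (not scvi's code)."""
    n_train = math.floor(train_size * n_cells)
    # With test_size=None, all remaining cells go to the test set (empty validation).
    n_test = n_cells - n_train if test_size is None else math.floor(test_size * n_cells)
    if n_train + n_test > n_cells:
        raise ValueError("train_size + test_size must not exceed 1")
    permutation = list(range(n_cells))
    random.Random(seed).shuffle(permutation)
    train = permutation[:n_train]
    test = permutation[n_train:n_train + n_test]
    validation = permutation[n_train + n_test:]  # non-empty only if sizes sum to < 1
    return train, test, validation


train, test, validation = split_indices(100, train_size=0.8, test_size=0.1)
print(len(train), len(test), len(validation))  # prints 80 10 10
```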


training_extras_end()

Hook for end-of-training work on extra models, such as putting them in eval mode.

training_extras_init(**extras_kwargs)

Hook for initializing other models that must be trained simultaneously.

Parameters

  • **extras_kwargs – Keyword arguments passed through from train().


uncorrupt_posteriors()