Trainer
class scvi.inference.Trainer(model, gene_dataset, use_cuda=True, metrics_to_monitor=None, benchmark=False, frequency=None, weight_decay=1e-06, early_stopping_kwargs=None, data_loader_kwargs=None, show_progbar=True, batch_size=128, seed=0, max_nans=10)
Bases: object
The abstract Trainer class for training a PyTorch model and monitoring its statistics.
It must be subclassed, and the subclass must at least implement a .loss() method, which defines the objective optimized in the training loop.
Parameters
model – A model instance from class VAE, VAEC, or SCANVI.
gene_dataset (GeneExpressionDataset) – A gene_dataset instance, such as CortexDataset().
metrics_to_monitor (Optional[List]) – A list of the metrics to monitor. If not specified, the default_metrics_to_monitor defined by each Trainer subclass is used. Default: None.
benchmark (bool) – If True, prevents statistics computation during training. Default: False.
frequency (Optional[int]) – The frequency, in epochs, at which to keep track of statistics. Default: None.
early_stopping_metric – The statistic on which to perform early stopping. Default: None.
save_best_state_metric – The statistic used to keep the network weights achieving the best score; these weights are restored at the end of training. Default: None.
on – The name of the data loader on which early_stopping_metric and save_best_state_metric are computed. Must be specified if either of them is. Default: None.
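early_stopping_metric, save_best_state_metric and on describe a common pattern: monitor one statistic, snapshot the weights whenever it improves, and restore the best snapshot when training ends. The plain-Python sketch below illustrates that pattern; it is not scvi's implementation. EarlyStoppingSketch is a hypothetical name. Lower-is-better scores and the patience counter are assumptions not stated in this documentation. The metric value is passed in directly instead of being computed on the data loader named by on.

```python
import copy
import math


class EarlyStoppingSketch:
    """Hypothetical helper: track one statistic, keep the best weights, and
    report when training should stop.

    Assumptions (not from the Trainer docs): lower scores are better, and
    training stops after `patience` epochs without improvement.
    """

    def __init__(self, patience=3):
        self.patience = patience
        self.best_score = math.inf
        self.best_state = None
        self.wait = 0  # epochs since the last improvement

    def update(self, score, state):
        """Record one epoch's score; return False once training should stop."""
        if score < self.best_score:
            self.best_score = score
            # Deep copy so later in-place weight updates do not alter the snapshot.
            self.best_state = copy.deepcopy(state)
            self.wait = 0
        else:
            self.wait += 1
        return self.wait < self.patience


# Toy run: the "weights" are a dict and the scores are made-up numbers.
weights = {"w": 0.0}
stopper = EarlyStoppingSketch(patience=2)
for epoch, score in enumerate([5.0, 3.0, 4.0, 2.5, 2.6, 2.7, 2.8]):
    weights["w"] = float(epoch)  # stand-in for one epoch of optimization
    if not stopper.update(score, weights):
        break
weights = stopper.best_state  # restore the best weights at the end of training
print(epoch, weights)  # 5 {'w': 3.0}
```

Copying the state on improvement, not on every epoch, keeps the cost proportional to the number of improvements.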
Attributes Summary
default_metrics_to_monitor
posteriors_loop
Methods Summary
check_training_status() – Checks if loss is admissible.
corrupt_posteriors([rate, corruption, …])
create_posterior([model, gene_dataset, …])
… – Returns a zipped iterable corresponding to the loss signature.
on_training_loop(tensors_list)
register_posterior(name, value)
train([n_epochs, lr, eps, params])
train_test_validation([model, gene_dataset, …]) – Creates posteriors train_set, test_set, and validation_set.
… – Place to put extra models in eval mode, etc.
training_extras_init(**extras_kwargs) – Other necessary models to train simultaneously.
Attributes Documentation
default_metrics_to_monitor = []
posteriors_loop
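The class is meant to be subclassed. A subclass supplies loss() and can override the class attribute default_metrics_to_monitor, which is used when metrics_to_monitor is not given. The following plain-Python sketch mirrors that contract. TrainerSketch and SquaredErrorTrainer are hypothetical names. Using abc.ABC to enforce the abstract method is an assumption about how such a base class could be written, not a description of scvi's code.

```python
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Hypothetical stand-in for an abstract trainer (not scvi's class)."""

    default_metrics_to_monitor = []  # subclasses override this

    def __init__(self, metrics_to_monitor=None):
        # Fall back to the class-level defaults when nothing is passed.
        self.metrics_to_monitor = (
            list(metrics_to_monitor)
            if metrics_to_monitor is not None
            else list(self.default_metrics_to_monitor)
        )

    @abstractmethod
    def loss(self, *tensors):
        """Return the scalar training loss for one step."""

    def train(self, batches):
        # Toy loop: real training would backpropagate each loss and step an optimizer.
        return [self.loss(*batch) for batch in batches]


class SquaredErrorTrainer(TrainerSketch):
    default_metrics_to_monitor = ["mse"]

    def loss(self, prediction, target):
        return (prediction - target) ** 2


trainer = SquaredErrorTrainer()
print(trainer.metrics_to_monitor)  # ['mse']
print(trainer.train([(3.0, 1.0), (2.0, 2.0)]))  # [4.0, 0.0]
```

Because loss is abstract, instantiating TrainerSketch directly raises TypeError. Only subclasses that implement it can be created.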
Methods Documentation
check_training_status()
Checks if loss is admissible.
If not, training is stopped after max_nans consecutive inadmissible losses; here, loss refers to the training loss of the model.
max_nans is the maximum number of consecutive NaNs after which a ValueError will be raised.
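The rule above amounts to a counter of consecutive NaN losses that resets on any admissible loss. In the minimal sketch below, NanGuard is a hypothetical helper name, and NaN is assumed to be the only inadmissible value:

```python
import math


class NanGuard:
    """Hypothetical helper: stop training after too many consecutive NaN losses."""

    def __init__(self, max_nans=10):
        self.max_nans = max_nans
        self.current_nans = 0  # length of the current run of NaN losses

    def check(self, loss):
        if math.isnan(loss):
            self.current_nans += 1
        else:
            self.current_nans = 0  # an admissible loss resets the run
        if self.current_nans >= self.max_nans:
            raise ValueError(
                f"Loss was NaN for {self.max_nans} consecutive iterations; stopping training."
            )


guard = NanGuard(max_nans=3)
for loss in [1.0, float("nan"), float("nan"), 0.5, float("nan")]:
    guard.check(loss)  # never three NaNs in a row, so no error is raised
print(guard.current_nans)  # 1
```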
create_posterior(model=None, gene_dataset=None, shuffle=False, indices=None, type_class=<class 'scvi.inference.posterior.Posterior'>)
train_test_validation(model=None, gene_dataset=None, train_size=0.9, test_size=None, type_class=<class 'scvi.inference.posterior.Posterior'>)
Creates posteriors train_set, test_set, and validation_set.
If train_size + test_size < 1, then validation_set is non-empty.
Parameters
train_size – float, or None (default is 0.9)
test_size – float, or None (default is None)
model – (Default value = None)
gene_dataset – (Default value = None)
type_class – (Default value = Posterior)
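The split reduces to partitioning shuffled cell indices into three groups before posteriors are built from them. The helper split_indices below is hypothetical and covers only the index split, not the creation of Posterior objects. Several details are assumptions that may differ from scvi:

- train_size is a float.
- The training set size is rounded down and the test set size is rounded up.
- When test_size is None, the test set takes every cell not used for training.
- The subsets are taken from the permutation in train, test, validation order.

```python
import math
import random


def split_indices(n_cells, train_size=0.9, test_size=None, seed=0):
    """Hypothetical helper: shuffle 0..n_cells-1 into train/test/validation.

    Assumed rounding: floor for train, ceil for test; with test_size=None the
    test set takes every index not used for training.
    """
    n_train = math.floor(train_size * n_cells)
    n_test = n_cells - n_train if test_size is None else math.ceil(test_size * n_cells)
    if n_train + n_test > n_cells:
        raise ValueError("train_size + test_size must not exceed 1")
    permutation = list(range(n_cells))
    random.Random(seed).shuffle(permutation)  # seeded for reproducibility
    train = permutation[:n_train]
    test = permutation[n_train:n_train + n_test]
    validation = permutation[n_train + n_test:]  # empty unless train + test < 1
    return train, test, validation


train, test, validation = split_indices(10, train_size=0.5, test_size=0.25)
print(len(train), len(test), len(validation))  # 5 3 2
```

With the defaults (train_size=0.9, test_size=None), the validation set is empty, matching the rule that it is non-empty only when train_size + test_size < 1.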
Returns