scvi.train.TrainingPlan

class scvi.train.TrainingPlan(module, lr=0.001, weight_decay=1e-06, eps=0.01, optimizer='Adam', n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, **loss_kwargs)[source]

Bases: pytorch_lightning.core.lightning.LightningModule

Lightning module task to train scvi-tools modules.

The training plan is a PyTorch Lightning Module that is initialized with a scvi-tools module object. It configures the optimizers, defines the training and validation steps, and computes metrics to be recorded during training. The training and validation steps are functions that take data, run it through the model, and return the loss; the Trainer then uses that loss to optimize the module's parameters. Overall, custom training plans can be used to develop complex inference schemes on top of modules. The following developer tutorial will familiarize you further with training plans and how to use them: Constructing a high-level model.
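
A minimal usage sketch, assuming module is an existing BaseModuleClass instance (e.g. a VAE) and train_dl / val_dl are pre-built data loaders (all three names are hypothetical):

    import pytorch_lightning as pl
    from scvi.train import TrainingPlan

    # Wrap the scvi-tools module in a training plan; the plan supplies the
    # optimizer configuration, training_step, and validation_step.
    training_plan = TrainingPlan(module, lr=1e-3)

    # Any PyTorch Lightning Trainer can then drive optimization.
    trainer = pl.Trainer(max_epochs=400)
    trainer.fit(training_plan, train_dl, val_dl)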

Parameters
module : BaseModuleClass

A module instance from class BaseModuleClass.

lr : float (default: 0.001)

Learning rate used for optimization.

weight_decay : float (default: 1e-06)

Weight decay used in optimization.

eps : float (default: 0.01)

Numerical stability constant (eps) used by the optimizer.

optimizer : Literal['Adam', 'AdamW'] (default: 'Adam')

One of "Adam" (torch.optim.Adam) or "AdamW" (torch.optim.AdamW).

n_steps_kl_warmup : Optional[int] (default: None)

Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

n_epochs_kl_warmup : Optional[int] (default: 400)

Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.

reduce_lr_on_plateau : bool (default: False)

Whether to monitor the validation metric given by lr_scheduler_metric and reduce the learning rate when it plateaus (see the example after this parameter list).

lr_factor : float (default: 0.6)

Factor by which to reduce the learning rate.

lr_patience : int (default: 30)

Number of epochs with no improvement after which learning rate will be reduced.

lr_threshold : float (default: 0.0)

Threshold for measuring the new optimum.

lr_scheduler_metric : Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')

Which metric to track for learning rate reduction.

lr_min : float (default: 0)

Minimum learning rate allowed.

**loss_kwargs

Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
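
For example, a plan that anneals the KL weight over the first 400 epochs and reduces the learning rate when the validation ELBO plateaus could be configured as follows (a sketch; module is assumed to already exist):

    from scvi.train import TrainingPlan

    training_plan = TrainingPlan(
        module,
        lr=1e-3,
        n_epochs_kl_warmup=400,  # scale KL weight from 0 to 1 over 400 epochs
        reduce_lr_on_plateau=True,  # enable plateau-based LR reduction
        lr_factor=0.6,  # multiply the learning rate by 0.6 on plateau
        lr_patience=30,  # wait 30 epochs without improvement first
        lr_scheduler_metric="elbo_validation",  # metric to monitor
    )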

Attributes

kl_weight

Scaling factor on the KL divergence during training (sketched below).

n_obs_training

Number of observations in the training set.
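
The kl_weight attribute implements the warm-up configured by n_epochs_kl_warmup / n_steps_kl_warmup. A sketch of the epoch-based case (the step-based case substitutes global steps for epochs; the real property may differ in detail):

    # Sketch of the KL warm-up schedule, not the exact property code.
    def kl_weight(current_epoch, n_epochs_kl_warmup):
        if n_epochs_kl_warmup is None:
            return 1.0  # no annealing configured
        # Linear ramp from 0 to 1 over the warm-up window.
        return min(1.0, current_epoch / n_epochs_kl_warmup)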

Methods

configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization.

forward(*args, **kwargs)

Passthrough to model.forward().

training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps.

training_step(batch, batch_idx[, optimizer_idx])

Compute and return the training loss and some additional metrics, e.g. for the progress bar or logger (see the subclassing sketch after this method list).

validation_epoch_end(outputs)

Aggregate validation step information.

validation_step(batch, batch_idx)

Operates on a single batch of data from the validation set.
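
Custom inference schemes are typically built by subclassing TrainingPlan and overriding these hooks. A hedged sketch (the subclass name and logged metric are illustrative, not part of scvi-tools):

    from scvi.train import TrainingPlan

    class LoggingTrainingPlan(TrainingPlan):
        # Hypothetical subclass that logs the current KL weight each step.

        def training_step(self, batch, batch_idx, optimizer_idx=0):
            # Delegate loss computation to the parent plan.
            out = super().training_step(batch, batch_idx, optimizer_idx)
            # Record the annealed KL weight alongside the usual metrics.
            self.log("kl_weight", self.kl_weight, on_epoch=True)
            return out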