class scvi.train.TrainingPlan(module, lr=0.001, weight_decay=1e-06, eps=0.01, optimizer='Adam', n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, **loss_kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

Lightning module task to train scvi-tools modules.

module : BaseModuleClass

A module instance from class BaseModuleClass.

lr : float (default: 0.001)

Learning rate used for optimization.

weight_decay : float (default: 1e-06)

Weight decay used in optimization.

eps : float (default: 0.01)

eps used for optimization.

optimizer : Literal['Adam', 'AdamW'] (default: 'Adam')

One of "Adam" (torch.optim.Adam) or "AdamW" (torch.optim.AdamW).

n_steps_kl_warmup : Optional[int] (default: None)

Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

n_epochs_kl_warmup : Optional[int] (default: 400)

Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.

reduce_lr_on_plateau : bool (default: False)

Whether to monitor validation loss and reduce the learning rate when the validation-set lr_scheduler_metric plateaus.

lr_factor : float (default: 0.6)

Factor by which to reduce the learning rate.

lr_patience : int (default: 30)

Number of epochs with no improvement after which learning rate will be reduced.

lr_threshold : float (default: 0.0)

Threshold for measuring the new optimum.

lr_scheduler_metric : Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')

Which metric to track for learning rate reduction.

lr_min : float (default: 0)

Minimum learning rate allowed.


loss_kwargs

Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
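The linear KL warm-up controlled by n_steps_kl_warmup / n_epochs_kl_warmup can be sketched as below. This is an illustrative re-implementation, not the scvi-tools source; kl_warmup_weight is a hypothetical helper name.

```python
# Sketch (assumption, not scvi-tools source): linearly scale the KL
# weight from 0 to 1 over a warm-up horizon of steps or epochs.

def kl_warmup_weight(current, warmup_total):
    """Return the KL scaling factor at position `current`.

    `warmup_total` is n_steps_kl_warmup or n_epochs_kl_warmup;
    None means warm-up is disabled and the weight is fixed at 1.0.
    """
    if warmup_total is None:
        return 1.0
    return min(1.0, current / warmup_total)
```

For example, halfway through a 400-epoch warm-up the KL term is weighted by 0.5, and after the horizon the weight stays at 1.0.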



kl_weight

Scaling factor on KL divergence during training.


n_obs_training

Number of observations in the training set.
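The plateau behaviour configured by lr_factor, lr_patience, lr_threshold, and lr_min follows the semantics of torch.optim.lr_scheduler.ReduceLROnPlateau. A minimal pure-Python sketch of that logic (PlateauReducer is a hypothetical name, not part of scvi-tools or PyTorch):

```python
# Sketch (assumption): reduce the learning rate by `factor` once the
# monitored metric has not improved by more than `threshold` for more
# than `patience` consecutive checks, never going below `lr_min`.

class PlateauReducer:
    def __init__(self, lr, factor=0.6, patience=30, threshold=0.0, lr_min=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.lr_min = lr_min
        self.best = float("inf")   # lower metric is better (e.g. validation ELBO)
        self.bad_epochs = 0

    def step(self, metric):
        # An improvement must beat the best value by more than `threshold`.
        if metric < self.best - self.threshold:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.lr_min)
            self.bad_epochs = 0
        return self.lr
```

With patience=2, a metric that stagnates for three checks in a row triggers one reduction of the learning rate by `factor`.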



configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization.

forward(*args, **kwargs)

Passthrough to model.forward().
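A minimal sketch of the passthrough pattern using a stand-in module (FakeModule and PlanSketch are hypothetical names for illustration, not scvi-tools classes):

```python
# Sketch (assumption): forward() does no work of its own and simply
# delegates all positional and keyword arguments to the wrapped module.

class FakeModule:
    """Stand-in for a BaseModuleClass instance."""
    def __call__(self, x, scale=1):
        return [v * scale for v in x]

class PlanSketch:
    def __init__(self, module):
        self.module = module

    def forward(self, *args, **kwargs):
        # Passthrough: same signature and return value as the module.
        return self.module(*args, **kwargs)

plan = PlanSketch(FakeModule())
```

Calling plan.forward([1, 2], scale=3) therefore returns exactly what FakeModule would: [3, 6].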


training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps.

training_step(batch, batch_idx[, optimizer_idx])

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.


validation_epoch_end(outputs)

Aggregate validation step information.

validation_step(batch, batch_idx)

Operates on a single batch of data from the validation set.