scvi.train.TrainingPlan

class scvi.train.TrainingPlan(module, *, optimizer='Adam', optimizer_creator=None, lr=0.001, weight_decay=1e-06, eps=0.01, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, max_kl_weight=1.0, min_kl_weight=0.0, **loss_kwargs)

Bases: TunableMixin, LightningModule

Lightning module task to train scvi-tools modules.

The training plan is a PyTorch Lightning Module that is initialized with an scvi-tools module object. It configures the optimizers, defines the training and validation steps, and computes the metrics recorded during training. The training and validation steps take a minibatch of data, run it through the module, and return the loss, which the Trainer then uses to optimize the module parameters. Custom training plans can therefore be used to build complex inference schemes on top of existing modules.
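The division of labour described above (the module computes the loss, the plan owns the optimizer and the train/validation steps, and a trainer loop drives both) can be illustrated with a dependency-free toy. This is a sketch of the pattern only, not the scvi-tools API: `ToyModule`, `ToyTrainingPlan` and `fit` are hypothetical names, and a real training plan subclasses `LightningModule` and is driven by the Lightning `Trainer`.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModule:
    """Hypothetical stand-in for a module with one weight: prediction = w * x."""

    w: float = 0.0

    def loss(self, batch):
        # Mean squared error over the minibatch, plus its gradient with respect to w.
        n = len(batch)
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / n
        return loss, grad


@dataclass
class ToyTrainingPlan:
    """Hypothetical plan: wraps the module, owns the optimizer and the steps."""

    module: ToyModule
    lr: float = 0.1
    history: dict = field(default_factory=lambda: {"train": [], "validation": []})

    def configure_optimizers(self):
        # Plain gradient descent, returned as a function that applies one update.
        def step(grad):
            self.module.w -= self.lr * grad

        return step

    def training_step(self, batch):
        loss, grad = self.module.loss(batch)
        self.history["train"].append(loss)
        return loss, grad

    def validation_step(self, batch):
        loss, _ = self.module.loss(batch)  # evaluation only, no parameter update
        self.history["validation"].append(loss)
        return loss


def fit(plan, train_batches, val_batches, max_epochs=50):
    """Toy stand-in for the Trainer: alternates training and validation epochs."""
    optimizer_step = plan.configure_optimizers()
    for _ in range(max_epochs):
        for batch in train_batches:
            _, grad = plan.training_step(batch)
            optimizer_step(grad)
        for batch in val_batches:
            plan.validation_step(batch)
    return plan


# Data generated by y = 3x, split into two training minibatches and one validation batch.
plan = fit(
    ToyTrainingPlan(ToyModule()),
    train_batches=[[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0)]],
    val_batches=[[(4.0, 12.0)]],
)
```

Because the optimizer and step logic live in the plan rather than the module, swapping in a different inference scheme means writing a new plan while reusing the same module.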

For a more detailed introduction to training plans and how to use them, see the developer tutorial: Constructing a high-level model.

Parameters:
  • module (BaseModuleClass) – An instance of BaseModuleClass (or a subclass of it).

  • optimizer (Tunable_[Literal['Adam', 'AdamW', 'Custom']] (default: 'Adam')) – One of “Adam” (Adam), “AdamW” (AdamW), or “Custom”, which requires a custom optimizer creator callable to be passed via optimizer_creator.

  • optimizer_creator (Optional[Callable[[Iterable[Tensor]], Optimizer]] (default: None)) – A callable that takes the parameters and returns an Optimizer. This allows using any PyTorch optimizer with custom hyperparameters.

  • lr (Tunable_[float] (default: 0.001)) – Learning rate used for optimization when optimizer_creator is None.

  • weight_decay (Tunable_[float] (default: 1e-06)) – Weight decay used in optimization when optimizer_creator is None.

  • eps (Tunable_[float] (default: 0.01)) – Epsilon term added to the optimizer's denominator for numerical stability, used when optimizer_creator is None.

  • n_steps_kl_warmup (Tunable_[int] (default: None)) – Number of training steps (minibatches) over which the weight on the KL divergence is scaled from min_kl_weight to max_kl_weight. Only used when n_epochs_kl_warmup is None.

  • n_epochs_kl_warmup (Tunable_[int] (default: 400)) – Number of epochs over which the weight on the KL divergence is scaled from min_kl_weight to max_kl_weight. Takes precedence over n_steps_kl_warmup when both are set.

  • reduce_lr_on_plateau (Tunable_[bool] (default: False)) – Whether to reduce the learning rate when the validation metric selected by lr_scheduler_metric plateaus.

  • lr_factor (Tunable_[float] (default: 0.6)) – Factor by which the learning rate is multiplied each time it is reduced.

  • lr_patience (Tunable_[int] (default: 30)) – Number of epochs with no improvement after which learning rate will be reduced.

  • lr_threshold (Tunable_[float] (default: 0.0)) – Threshold for measuring the new optimum.

  • lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')) – Which metric to track for learning rate reduction.

  • lr_min (Tunable_[float] (default: 0)) – Minimum learning rate allowed.

  • max_kl_weight (Tunable_[float] (default: 1.0)) – Maximum scaling factor on KL divergence during training.

  • min_kl_weight (Tunable_[float] (default: 0.0)) – Minimum scaling factor on KL divergence during training.

  • **loss_kwargs – Keyword arguments passed to the module's loss method. Do not pass kl_weight here; it is handled automatically.
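When reduce_lr_on_plateau is True, lr_factor, lr_patience, lr_threshold and lr_min together determine how the learning rate decays as the tracked lr_scheduler_metric stalls. The plain-Python sketch below shows how these settings interact. It assumes the behaviour of PyTorch's ReduceLROnPlateau in "min" mode with an absolute threshold; that is an assumption for illustration, and `plateau_schedule` is a hypothetical helper, not part of scvi-tools.

```python
def plateau_schedule(metric_history, lr=1e-3, lr_factor=0.6, lr_patience=30,
                     lr_threshold=0.0, lr_min=0.0):
    """Return the learning rate in effect after each epoch's validation metric.

    Assumes lower is better (as for elbo_validation) and that an epoch only counts
    as an improvement if it beats the best value so far by more than lr_threshold.
    """
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for value in metric_history:
        if value < best - lr_threshold:
            best = value
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce once the metric has failed to improve for more than lr_patience epochs,
        # never going below lr_min.
        if bad_epochs > lr_patience:
            lr = max(lr * lr_factor, lr_min)
            bad_epochs = 0
        lrs.append(lr)
    return lrs


# Metric improves for three epochs, then stalls.
lrs = plateau_schedule([10.0, 9.0, 8.0] + [8.0] * 5, lr_patience=2)
```

Note the strict comparison: under these assumptions, with lr_patience=30 the rate is first cut on the 31st consecutive epoch without improvement.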

Attributes table

kl_weight

Scaling factor on KL divergence during training.

n_obs_training

Number of observations in the training set.

n_obs_validation

Number of observations in the validation set.

use_sync_dist

training

Methods table

compute_and_log_metrics(loss_output, ...)

Computes and logs metrics.

configure_optimizers()

Configure optimizers for the model.

forward(*args, **kwargs)

Passthrough to the module's forward method.

get_optimizer_creator()

Get optimizer creator for the model.

initialize_train_metrics()

Initialize training-related metrics.

initialize_val_metrics()

Initialize validation-related metrics.

training_step(batch, batch_idx)

Training step for the model.

validation_step(batch, batch_idx)

Validation step for the model.

Attributes

TrainingPlan.kl_weight

Scaling factor on KL divergence during training.
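The kl_weight value follows the warmup configured by n_epochs_kl_warmup, n_steps_kl_warmup, min_kl_weight and max_kl_weight, with the epoch-based setting taking precedence as documented above. The parameter descriptions only say the weight is "scaled", so the linear ramp below is an assumption; `kl_weight_schedule` is a hypothetical helper for illustration.

```python
def kl_weight_schedule(epoch, step, n_epochs_kl_warmup=400, n_steps_kl_warmup=None,
                       min_kl_weight=0.0, max_kl_weight=1.0):
    """Sketch of a KL warmup; the linear shape of the ramp is an assumption."""
    span = max_kl_weight - min_kl_weight
    if n_epochs_kl_warmup:  # epoch-based warmup takes precedence when set
        if epoch < n_epochs_kl_warmup:
            return min_kl_weight + span * epoch / n_epochs_kl_warmup
    elif n_steps_kl_warmup:  # step-based warmup only when the epoch setting is unset
        if step < n_steps_kl_warmup:
            return min_kl_weight + span * step / n_steps_kl_warmup
    return max_kl_weight  # warmup finished or disabled


halfway = kl_weight_schedule(epoch=200, step=0)  # halfway through the default 400 epochs
```

Starting below max_kl_weight lets the reconstruction term dominate early in training before the full KL penalty applies.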

TrainingPlan.n_obs_training

Number of observations in the training set.

Setting this value also updates the loss kwargs used for loss rescaling.

Notes

This can be set after initialization.

TrainingPlan.n_obs_validation

Number of observations in the validation set.

Setting this value also updates the loss kwargs used for loss rescaling.

Notes

This can be set after initialization.
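The notes above say that n_obs_training and n_obs_validation can be set after initialization and that doing so updates the loss kwargs for rescaling. The class below is a hypothetical, dependency-free sketch of that property-setter pattern for the training count. The kwarg name `n_obs`, the check on the loss signature, and the names `ToyPlan` and `rescaled_loss` are assumptions for illustration; n_obs_validation would follow the same pattern.

```python
import inspect


def rescaled_loss(batch, kl_weight=1.0, n_obs=1):
    """Hypothetical module loss that accepts n_obs for rescaling."""
    return 0.0


class ToyPlan:
    """Hypothetical plan showing a setter that forwards n_obs to the loss kwargs."""

    def __init__(self, loss_fn, **loss_kwargs):
        self.loss_kwargs = dict(loss_kwargs)
        # Parameter names the loss accepts; used to decide whether to forward n_obs.
        self._loss_args = set(inspect.signature(loss_fn).parameters)
        self._n_obs_training = None

    @property
    def n_obs_training(self):
        return self._n_obs_training

    @n_obs_training.setter
    def n_obs_training(self, n_obs):
        # Forward the value only if the loss can actually use it.
        if "n_obs" in self._loss_args:
            self.loss_kwargs["n_obs"] = n_obs
        self._n_obs_training = n_obs


plan = ToyPlan(rescaled_loss)
plan.n_obs_training = 1000  # set after initialization
```

Using a property keeps the stored count and the loss kwargs in sync, so the data loaders can be set up before the dataset size is known.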

TrainingPlan.use_sync_dist

TrainingPlan.training: bool

Methods

TrainingPlan.compute_and_log_metrics(loss_output, metrics, mode)

Computes and logs metrics.

Parameters:
  • loss_output (LossOutput) – LossOutput object returned by the scvi-tools module.

  • metrics (dict[str, ElboMetric]) – Dictionary of metrics to update.

  • mode (str) – Postfix string added to the names of the extra metrics.
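Metric names such as elbo_validation, reconstruction_loss_validation and kl_local_validation (the choices for lr_scheduler_metric) combine a loss component with a mode postfix. The sketch below is a hypothetical accumulator in that spirit: it sums per-minibatch totals and divides by the number of observations at the end of the epoch. `ToyElboMetric` is not the scvi-tools ElboMetric class, and the exact accumulation rule used there is an assumption here.

```python
from dataclasses import dataclass


@dataclass
class ToyElboMetric:
    """Hypothetical accumulator: sums minibatch totals, averages per observation."""

    name: str  # e.g. "elbo", "reconstruction_loss", "kl_local"
    mode: str  # e.g. "train" or "validation"
    total: float = 0.0
    n_obs: int = 0

    @property
    def log_name(self):
        # Produces names such as "elbo_validation".
        return f"{self.name}_{self.mode}"

    def update(self, minibatch_sum, n_obs_minibatch):
        self.total += minibatch_sum
        self.n_obs += n_obs_minibatch

    def compute(self):
        return self.total / self.n_obs

    def reset(self):
        self.total, self.n_obs = 0.0, 0


metric = ToyElboMetric("elbo", "validation")
metric.update(30.0, 10)  # minibatch mean 3.0
metric.update(10.0, 2)   # minibatch mean 5.0
```

Here `metric.compute()` gives 40 / 12 (about 3.33) rather than the mean of minibatch means (4.0), so an uneven final minibatch does not skew the epoch-level value.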

TrainingPlan.configure_optimizers()

Configure optimizers for the model.

TrainingPlan.forward(*args, **kwargs)

Passthrough to the module’s forward method.

TrainingPlan.get_optimizer_creator()

Get optimizer creator for the model.
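An optimizer creator is a callable that maps an iterable of parameters to an optimizer. The sketch below shows how the optimizer, optimizer_creator, lr, weight_decay and eps arguments could resolve to such a callable, following the parameter descriptions: "Custom" requires optimizer_creator, and the hyperparameters apply only to the built-in choices. `ToyOptimizer` is a hypothetical stand-in for `torch.optim.Adam` / `torch.optim.AdamW` so the example runs without PyTorch, and this `get_optimizer_creator` is a free function for illustration, not the actual method.

```python
from functools import partial


class ToyOptimizer:
    """Hypothetical stand-in for a torch.optim optimizer; it only records settings."""

    def __init__(self, params, **hyperparams):
        self.params = list(params)
        self.hyperparams = hyperparams


def get_optimizer_creator(optimizer="Adam", optimizer_creator=None,
                          lr=1e-3, weight_decay=1e-6, eps=0.01):
    """Return a callable mapping an iterable of parameters to an optimizer."""
    if optimizer == "Custom":
        if optimizer_creator is None:
            raise ValueError('optimizer="Custom" requires optimizer_creator')
        return optimizer_creator  # user callable; lr/weight_decay/eps are ignored
    if optimizer in ("Adam", "AdamW"):
        # The real plan would build torch.optim.Adam or torch.optim.AdamW here.
        return partial(ToyOptimizer, name=optimizer, lr=lr,
                       weight_decay=weight_decay, eps=eps)
    raise ValueError(f"Unknown optimizer: {optimizer!r}")


adamw_creator = get_optimizer_creator("AdamW", lr=0.01)
custom_creator = get_optimizer_creator(
    "Custom", optimizer_creator=lambda params: ToyOptimizer(params, name="SGD", lr=0.1)
)
```

Returning a creator instead of an optimizer lets the plan defer construction until the parameters to optimize are known.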

TrainingPlan.initialize_train_metrics()

Initialize training-related metrics.

TrainingPlan.initialize_val_metrics()

Initialize validation-related metrics.

TrainingPlan.training_step(batch, batch_idx)

Training step for the model.

TrainingPlan.validation_step(batch, batch_idx)

Validation step for the model.
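The **loss_kwargs description says kl_weight is handled automatically. A minimal sketch of what that can look like inside a training step, assuming the scheduled weight is injected only when the module's loss accepts a kl_weight argument: `training_step`, `toy_loss` and the `log` callback here are hypothetical, and the loss form (reconstruction plus kl_weight times KL) is a toy stand-in.

```python
import inspect


def training_step(loss_fn, batch, kl_weight, log, **loss_kwargs):
    """Hypothetical step: inject the scheduled kl_weight, compute and log the loss."""
    if "kl_weight" in inspect.signature(loss_fn).parameters:
        loss_kwargs["kl_weight"] = kl_weight  # overrides any user-supplied value
        log("kl_weight", kl_weight)
    loss = loss_fn(batch, **loss_kwargs)
    log("train_loss", loss)
    return loss


def toy_loss(batch, kl_weight=1.0):
    """Toy loss: reconstruction term plus KL term scaled by kl_weight."""
    return sum(batch["reconstruction"]) + kl_weight * sum(batch["kl"])


logged = []
loss = training_step(
    toy_loss,
    {"reconstruction": [1.0, 2.0], "kl": [4.0]},
    kl_weight=0.5,
    log=lambda name, value: logged.append((name, value)),
)
```

A validation step would compute and log the loss in the same way, without any optimizer update.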