scvi.train.TrainingPlan

class scvi.train.TrainingPlan(module, *, optimizer='Adam', optimizer_creator=None, lr=0.001, weight_decay=1e-06, eps=0.01, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, max_kl_weight=1.0, min_kl_weight=0.0, **loss_kwargs)

Bases: TunableMixin, LightningModule

Lightning module task to train scvi-tools modules.

The training plan is a PyTorch Lightning Module that is initialized with a scvi-tools module object. It configures the optimizers, defines the training and validation steps, and computes the metrics recorded during training. The training and validation steps are functions that take a batch of data, run it through the module, and return the loss, which the Trainer then uses to optimize the model parameters. More broadly, custom training plans can be used to develop complex inference schemes on top of modules.

The following developer tutorial provides a deeper look at training plans and how to use them: Constructing a high-level model.
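As a quick orientation, here is a minimal sketch of two common ways to use a training plan. It assumes `adata` is an AnnData object already registered via scvi.model.SCVI.setup_anndata; the keyword arguments shown are illustrative, not recommendations.

```python
import scvi
from scvi.train import TrainingPlan

# Assumption: `adata` has already been set up with
# scvi.model.SCVI.setup_anndata(adata).
model = scvi.model.SCVI(adata)

# Option 1: wrap the model's module in a plan directly, e.g. when
# driving training with your own pytorch_lightning.Trainer.
plan = TrainingPlan(
    model.module,
    lr=1e-3,
    n_epochs_kl_warmup=400,
    reduce_lr_on_plateau=True,
)

# Option 2 (the usual high-level route): let the model build the plan,
# forwarding the same keyword arguments through `plan_kwargs`.
model.train(
    max_epochs=100,
    plan_kwargs={"lr": 1e-3, "reduce_lr_on_plateau": True},
)
```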

Parameters:
  • module (BaseModuleClass) – A module instance from class BaseModuleClass.

  • optimizer (Tunable_[Literal['Adam', 'AdamW', 'Custom']]) – One of “Adam” (torch.optim.Adam), “AdamW” (torch.optim.AdamW), or “Custom”, which requires a custom optimizer creator callable to be passed via optimizer_creator.

  • optimizer_creator (Optional[Callable[[Iterable[Tensor]], Optimizer]]) – A callable taking in parameters and returning an Optimizer. This allows using any PyTorch optimizer with custom hyperparameters (see the sketch after this parameter list).

  • lr (Tunable_[float]) – Learning rate used for optimization, when optimizer_creator is None.

  • weight_decay (Tunable_[float]) – Weight decay used in optimization, when optimizer_creator is None.

  • eps (Tunable_[float]) – eps used for optimization, when optimizer_creator is None.

  • n_steps_kl_warmup (Tunable_[int]) – Number of training steps (minibatches) to scale weight on KL divergences from min_kl_weight to max_kl_weight. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup (Tunable_[int]) – Number of epochs to scale weight on KL divergences from min_kl_weight to max_kl_weight. Overrides n_steps_kl_warmup when both are not None.

  • reduce_lr_on_plateau (Tunable_[bool]) – Whether to reduce the learning rate when the validation lr_scheduler_metric plateaus.

  • lr_factor (Tunable_[float]) – Factor by which to reduce the learning rate.

  • lr_patience (Tunable_[int]) – Number of epochs with no improvement after which learning rate will be reduced.

  • lr_threshold (Tunable_[float]) – Threshold for measuring the new optimum.

  • lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']) – Which metric to track for learning rate reduction.

  • lr_min (Tunable_[float]) – Minimum learning rate allowed.

  • max_kl_weight (Tunable_[float]) – Maximum scaling factor on KL divergence during training.

  • min_kl_weight (Tunable_[float]) – Minimum scaling factor on KL divergence during training.

  • **loss_kwargs – Keyword arguments to pass to the loss method of the module. kl_weight should not be passed here; it is handled automatically.
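When optimizer is set to “Custom”, a creator callable must be supplied. A minimal sketch, assuming `module` is any BaseModuleClass instance (for example, `model.module` from the snippet above) and using illustrative SGD hyperparameters:

```python
from collections.abc import Iterable

import torch

from scvi.train import TrainingPlan

def sgd_creator(params: Iterable[torch.Tensor]) -> torch.optim.Optimizer:
    # Any callable mapping the module's parameters to a torch.optim.Optimizer works.
    return torch.optim.SGD(params, lr=1e-2, momentum=0.9)

plan = TrainingPlan(module, optimizer="Custom", optimizer_creator=sgd_creator)
```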

Attributes table

kl_weight

Scaling factor on KL divergence during training.

n_obs_training

Number of observations in the training set.

n_obs_validation

Number of observations in the validation set.

Methods table

compute_and_log_metrics(loss_output, ...)

Computes and logs metrics.

configure_optimizers()

Configure optimizers for the model.

forward(*args, **kwargs)

Passthrough to the module's forward method.

get_optimizer_creator()

Get optimizer creator for the model.

initialize_train_metrics()

Initialize train-related metrics.

initialize_val_metrics()

Initialize val-related metrics.

training_step(batch, batch_idx[, optimizer_idx])

Training step for the model.

validation_step(batch, batch_idx)

Validation step for the model.

Attributes

kl_weight

TrainingPlan.kl_weight

Scaling factor on KL divergence during training.
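The warmup parameters described above imply a linear annealing schedule. A rough sketch of how such a weight can be computed (an illustration of the documented behavior, not the verbatim scvi-tools source):

```python
from typing import Optional

def linear_kl_weight(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: Optional[int],
    n_steps_kl_warmup: Optional[int],
    min_kl_weight: float = 0.0,
    max_kl_weight: float = 1.0,
) -> float:
    """Linearly anneal the KL weight from min_kl_weight to max_kl_weight."""
    if n_epochs_kl_warmup is not None:
        # Epoch-based warmup takes precedence when both are set.
        proportion = epoch / n_epochs_kl_warmup
    elif n_steps_kl_warmup is not None:
        proportion = step / n_steps_kl_warmup
    else:
        return max_kl_weight  # no warmup configured
    slope = max_kl_weight - min_kl_weight
    return min(max_kl_weight, min_kl_weight + slope * proportion)
```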

n_obs_training

TrainingPlan.n_obs_training

Number of observations in the training set.

This will update the loss kwargs for loss rescaling.

Notes

This can be set after initialization.

n_obs_validation

TrainingPlan.n_obs_validation

Number of observations in the validation set.

This will update the loss kwargs for loss rescaling.

Notes

This can be set after initialization.

training

TrainingPlan.training: bool

Methods

compute_and_log_metrics

TrainingPlan.compute_and_log_metrics(loss_output, metrics, mode)

Computes and logs metrics.

Parameters:
  • loss_output (LossOutput) – LossOutput object from a scvi-tools module

  • metrics (Dict[str, ElboMetric]) – Dictionary of metrics to update

  • mode (str) – Postfix string appended to the names of extra metrics

configure_optimizers

TrainingPlan.configure_optimizers()

Configure optimizers for the model.
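As a LightningModule hook, this returns either a bare optimizer or an optimizer paired with a scheduler configuration. A sketch of the Lightning convention the plan follows when reduce_lr_on_plateau=True; the attribute names mirror the constructor arguments and are assumptions here, and this is illustrative rather than the verbatim scvi-tools source:

```python
import torch

from scvi.train import TrainingPlan

class PlateauPlan(TrainingPlan):
    def configure_optimizers(self):
        # Build the optimizer from the configured creator.
        optimizer = self.get_optimizer_creator()(self.module.parameters())
        if not self.reduce_lr_on_plateau:
            return optimizer
        # Pair the optimizer with a plateau scheduler; Lightning steps it
        # against the monitored validation metric.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=self.lr_factor,
            patience=self.lr_patience,
            threshold=self.lr_threshold,
            min_lr=self.lr_min,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": self.lr_scheduler_metric,
            },
        }
```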

forward

TrainingPlan.forward(*args, **kwargs)

Passthrough to the module’s forward method.

get_optimizer_creator

TrainingPlan.get_optimizer_creator()

Get optimizer creator for the model.

initialize_train_metrics

TrainingPlan.initialize_train_metrics()

Initialize train-related metrics.

initialize_val_metrics

TrainingPlan.initialize_val_metrics()

Initialize val-related metrics.

training_step

TrainingPlan.training_step(batch, batch_idx, optimizer_idx=0)

Training step for the model.
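In outline, the step runs the batch through the module's forward/loss path, logs the loss and metrics, and returns the loss for backpropagation. A minimal sketch, close in spirit to but not necessarily identical to the scvi-tools source; `loss_kwargs` and `train_metrics` are assumed to be the attributes populated by the constructor and by initialize_train_metrics:

```python
from scvi.train import TrainingPlan

class LoggingPlan(TrainingPlan):
    def training_step(self, batch, batch_idx, optimizer_idx=0):
        # Refresh the annealed KL weight before computing the loss.
        if "kl_weight" in self.loss_kwargs:
            self.loss_kwargs.update({"kl_weight": self.kl_weight})
        # The module's forward/loss path returns a LossOutput as its
        # third value (see forward above).
        _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
        self.log("train_loss", scvi_loss.loss, on_epoch=True)
        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
        return scvi_loss.loss
```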

validation_step

TrainingPlan.validation_step(batch, batch_idx)

Validation step for the model.