scvi.train.TrainingPlan#
- class scvi.train.TrainingPlan(module, *, optimizer='Adam', optimizer_creator=None, lr=0.001, weight_decay=1e-06, eps=0.01, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, max_kl_weight=1.0, min_kl_weight=0.0, **loss_kwargs)[source]#
Bases: TunableMixin, LightningModule
Lightning module task to train scvi-tools modules.
The training plan is a PyTorch Lightning Module that is initialized with a scvi-tools module object. It configures the optimizers, defines the training step and validation step, and computes metrics to be recorded during training. The training step and validation step are functions that take data, run it through the model and return the loss, which will then be used to optimize the model parameters in the Trainer. Overall, custom training plans can be used to develop complex inference schemes on top of modules.
The following developer tutorial will familiarize you more with training plans and how to use them: Constructing a high-level model.
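As a rough mental model of what a training plan does, here is a pure-Python sketch (not the actual scvi-tools or Lightning API; `ToyModule` and `ToyTrainingPlan` are hypothetical stand-ins): the plan holds a module, pairs its loss with an optimizer, and defines the per-batch training and validation steps.

```python
class ToyModule:
    """Stand-in for a scvi-tools module: one scalar parameter, quadratic loss."""

    def __init__(self, w=5.0):
        self.w = w

    def loss(self, batch):
        # mean squared error between w and the observations
        return sum((self.w - x) ** 2 for x in batch) / len(batch)

    def grad(self, batch):
        # d/dw of the loss above
        return 2 * sum(self.w - x for x in batch) / len(batch)


class ToyTrainingPlan:
    """Mirrors the TrainingPlan responsibilities: hold the module,
    configure an optimizer, and define training/validation steps."""

    def __init__(self, module, lr=0.1):
        self.module = module
        self.lr = lr  # stands in for the `lr` constructor argument

    def training_step(self, batch):
        loss = self.module.loss(batch)
        # plain gradient descent stands in for Adam/AdamW here
        self.module.w -= self.lr * self.module.grad(batch)
        return loss

    def validation_step(self, batch):
        # validation only evaluates the loss; no parameter update
        return self.module.loss(batch)


plan = ToyTrainingPlan(ToyModule(w=5.0), lr=0.1)
data = [1.0, 2.0, 3.0]  # "observations" with mean 2.0
losses = [plan.training_step(data) for _ in range(50)]
```

In the real class these responsibilities are spread across `configure_optimizers`, `training_step`, and `validation_step`, and the Lightning `Trainer` drives the loop.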
- Parameters:
  - module (BaseModuleClass) – A module instance from class BaseModuleClass.
  - optimizer (Tunable_[Literal['Adam', 'AdamW', 'Custom']]) – One of "Adam" (Adam), "AdamW" (AdamW), or "Custom", which requires a custom optimizer creator callable to be passed via optimizer_creator.
  - optimizer_creator (Optional[Callable[[Iterable[Tensor]], Optimizer]]) – A callable taking in parameters and returning an Optimizer. This allows using any PyTorch optimizer with custom hyperparameters.
  - lr (Tunable_[float]) – Learning rate used for optimization, when optimizer_creator is None.
  - weight_decay (Tunable_[float]) – Weight decay used in optimization, when optimizer_creator is None.
  - eps (Tunable_[float]) – eps used for optimization, when optimizer_creator is None.
  - n_steps_kl_warmup (Tunable_[int]) – Number of training steps (minibatches) over which to scale the weight on KL divergences from min_kl_weight to max_kl_weight. Only activated when n_epochs_kl_warmup is set to None.
  - n_epochs_kl_warmup (Tunable_[int]) – Number of epochs over which to scale the weight on KL divergences from min_kl_weight to max_kl_weight. Overrides n_steps_kl_warmup when both are not None.
  - reduce_lr_on_plateau (Tunable_[bool]) – Whether to monitor validation loss and reduce the learning rate when the validation lr_scheduler_metric plateaus.
  - lr_factor (Tunable_[float]) – Factor by which to reduce the learning rate.
  - lr_patience (Tunable_[int]) – Number of epochs with no improvement after which the learning rate will be reduced.
  - lr_threshold (Tunable_[float]) – Threshold for measuring the new optimum.
  - lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']) – Which metric to track for learning rate reduction.
  - lr_min (Tunable_[float]) – Minimum learning rate allowed.
  - max_kl_weight (Tunable_[float]) – Maximum scaling factor on KL divergence during training.
  - min_kl_weight (Tunable_[float]) – Minimum scaling factor on KL divergence during training.
  - **loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here; it is handled automatically.
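To make the interaction between the KL-warmup parameters concrete, here is a minimal sketch of linear annealing as the descriptions above define it (`current_kl_weight` is a hypothetical helper, not the library's implementation, which may differ in details):

```python
def current_kl_weight(
    step,
    epoch,
    *,
    n_steps_kl_warmup=None,
    n_epochs_kl_warmup=400,
    min_kl_weight=0.0,
    max_kl_weight=1.0,
):
    """Linearly anneal the KL weight from min_kl_weight to max_kl_weight.

    Epoch-based warmup takes precedence over step-based warmup, matching
    the documented behavior of n_epochs_kl_warmup overriding
    n_steps_kl_warmup when both are set.
    """
    if n_epochs_kl_warmup is not None:
        progress = epoch / n_epochs_kl_warmup
    elif n_steps_kl_warmup is not None:
        progress = step / n_steps_kl_warmup
    else:
        # no warmup configured: use the full KL weight from the start
        return max_kl_weight
    return min(
        max_kl_weight,
        min_kl_weight + (max_kl_weight - min_kl_weight) * progress,
    )
```

For example, with the defaults the weight reaches 0.5 at epoch 200 and is clamped at 1.0 from epoch 400 onward.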
Attributes table#
| Attribute | Description |
| --- | --- |
| kl_weight | Scaling factor on KL divergence during training. |
| n_obs_training | Number of observations in the training set. |
| n_obs_validation | Number of observations in the validation set. |
Methods table#
| Method | Description |
| --- | --- |
| compute_and_log_metrics | Computes and logs metrics. |
| configure_optimizers | Configure optimizers for the model. |
| forward | Passthrough to the module's forward method. |
| get_optimizer_creator | Get optimizer creator for the model. |
| initialize_train_metrics | Initialize train related metrics. |
| initialize_val_metrics | Initialize val related metrics. |
| training_step | Training step for the model. |
| validation_step | Validation step for the model. |
Attributes#
kl_weight
- TrainingPlan.kl_weight[source]#
Scaling factor on KL divergence during training.
n_obs_training
- TrainingPlan.n_obs_training[source]#
Number of observations in the training set.
This will update the loss kwargs for loss rescaling.
Notes
This can be set after initialization.
n_obs_validation
- TrainingPlan.n_obs_validation[source]#
Number of observations in the validation set.
This will update the loss kwargs for loss rescaling.
Notes
This can be set after initialization.
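A common reason a training plan tracks the number of observations is to rescale minibatch losses so they estimate the objective over the full dataset. The helper below (`rescale_minibatch_loss` is a hypothetical name, illustrating the idea rather than the library's exact rescaling) shows the arithmetic under that assumption:

```python
def rescale_minibatch_loss(summed_batch_loss, batch_size, n_obs):
    """Scale a loss summed over one minibatch so it is an unbiased
    estimate of the loss summed over all n_obs observations."""
    return summed_batch_loss * (n_obs / batch_size)


# a batch of 5 observations with total loss 10.0, out of 100 observations
full_dataset_estimate = rescale_minibatch_loss(10.0, batch_size=5, n_obs=100)
```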
training
Methods#
compute_and_log_metrics
- TrainingPlan.compute_and_log_metrics(loss_output, metrics, mode)[source]#
Computes and logs metrics.
- Parameters:
loss_output (LossOutput) – LossOutput object from scvi-tools module
metrics (Dict[str, ElboMetric]) – Dictionary of metrics to update
mode (str) – Postfix string to add to the metric name of extra metrics
configure_optimizers
forward
get_optimizer_creator
initialize_train_metrics
initialize_val_metrics
training_step
validation_step