scvi.train.TrainingPlan

- class scvi.train.TrainingPlan(module, *, optimizer='Adam', optimizer_creator=None, lr=0.001, weight_decay=1e-06, eps=0.01, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, max_kl_weight=1.0, min_kl_weight=0.0, **loss_kwargs)

Bases: LightningModule

Lightning module task to train scvi-tools modules.
The training plan is a PyTorch Lightning Module that is initialized with an scvi-tools module object. It configures the optimizers, defines the training and validation steps, and computes metrics to be recorded during training. The training step and validation step are functions that take data, run it through the model, and return the loss, which the Trainer then uses to optimize the model parameters. Overall, custom training plans can be used to develop complex inference schemes on top of modules.
The following developer tutorial will familiarize you with training plans and how to use them: Constructing a high-level model.
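In practice, a TrainingPlan is usually constructed internally by a model's `train` method, with its keyword arguments forwarded through `plan_kwargs`. A minimal sketch (the model and data setup are assumed, not shown):

```python
# Sketch only: these keyword arguments map onto the TrainingPlan
# parameters documented below.
plan_kwargs = {
    "lr": 1e-3,                    # learning rate (default 0.001)
    "n_epochs_kl_warmup": 400,     # linear KL warmup over 400 epochs
    "reduce_lr_on_plateau": True,  # shrink lr when elbo_validation plateaus
    "lr_factor": 0.6,              # multiply lr by 0.6 on plateau
}
# Assumed usage with a fitted-up scvi-tools model object:
# model.train(max_epochs=100, plan_kwargs=plan_kwargs)
```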
- Parameters:
  - module (BaseModuleClass) – A module instance from class BaseModuleClass.
  - optimizer (Literal['Adam', 'AdamW', 'Custom'], default: 'Adam') – One of "Adam" (Adam), "AdamW" (AdamW), or "Custom", which requires a custom optimizer creator callable to be passed via optimizer_creator.
  - optimizer_creator (Callable[[Iterable[Tensor]], Optimizer] | None, default: None) – A callable taking in parameters and returning an Optimizer. This allows using any PyTorch optimizer with custom hyperparameters.
  - lr (float, default: 0.001) – Learning rate used for optimization when optimizer_creator is None.
  - weight_decay (float, default: 1e-06) – Weight decay used in optimization when optimizer_creator is None.
  - eps (float, default: 0.01) – eps used for optimization when optimizer_creator is None.
  - n_steps_kl_warmup (int, default: None) – Number of training steps (minibatches) over which to scale the weight on KL divergences from min_kl_weight to max_kl_weight. Only activated when n_epochs_kl_warmup is set to None.
  - n_epochs_kl_warmup (int, default: 400) – Number of epochs over which to scale the weight on KL divergences from min_kl_weight to max_kl_weight. Overrides n_steps_kl_warmup when both are not None.
  - reduce_lr_on_plateau (bool, default: False) – Whether to monitor validation loss and reduce the learning rate when the validation lr_scheduler_metric plateaus.
  - lr_factor (float, default: 0.6) – Factor by which to reduce the learning rate.
  - lr_patience (int, default: 30) – Number of epochs with no improvement after which the learning rate will be reduced.
  - lr_threshold (float, default: 0.0) – Threshold for measuring the new optimum.
  - lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'], default: 'elbo_validation') – Which metric to track for learning rate reduction.
  - lr_min (float, default: 0) – Minimum learning rate allowed.
  - max_kl_weight (float, default: 1.0) – Maximum scaling factor on KL divergence during training.
  - min_kl_weight (float, default: 0.0) – Minimum scaling factor on KL divergence during training.
  - **loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here; it is handled automatically.
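The KL warmup described by n_steps_kl_warmup / n_epochs_kl_warmup is a linear ramp from min_kl_weight to max_kl_weight. The following is an illustrative reimplementation of that schedule, not the library's exact code; the function name is hypothetical:

```python
def linear_kl_weight(progress, n_warmup, min_kl_weight=0.0, max_kl_weight=1.0):
    """Linearly ramp the KL weight from min_kl_weight to max_kl_weight.

    `progress` is the current epoch (or step, when n_steps_kl_warmup is used);
    `n_warmup` is n_epochs_kl_warmup or n_steps_kl_warmup. Illustrative only.
    """
    if n_warmup is None:
        # No warmup configured: use the maximum weight from the start.
        return max_kl_weight
    ramp = min_kl_weight + (max_kl_weight - min_kl_weight) * progress / n_warmup
    # Clamp at max_kl_weight once warmup has finished.
    return min(max_kl_weight, ramp)
```

For example, with the default n_epochs_kl_warmup=400, the weight is 0.5 halfway through warmup and stays at 1.0 afterwards.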
Attributes table

- kl_weight – Scaling factor on KL divergence during training.
- n_obs_training – Number of observations in the training set.
- n_obs_validation – Number of observations in the validation set.

Methods table

- compute_and_log_metrics – Computes and logs metrics.
- configure_optimizers – Configure optimizers for the model.
- forward – Passthrough to the module's forward method.
- get_optimizer_creator – Get optimizer creator for the model.
- initialize_train_metrics – Initialize train related metrics.
- initialize_val_metrics – Initialize val related metrics.
- training_step – Training step for the model.
- validation_step – Validation step for the model.
Attributes

- TrainingPlan.n_obs_training
  Number of observations in the training set.
  This will update the loss kwargs for loss rescaling.
  Notes: This can be set after initialization.
- TrainingPlan.n_obs_validation
  Number of observations in the validation set.
  This will update the loss kwargs for loss rescaling.
  Notes: This can be set after initialization.
- TrainingPlan.training: bool

Methods
- TrainingPlan.compute_and_log_metrics(loss_output, metrics, mode)

  Computes and logs metrics.

  - Parameters:
    - loss_output (LossOutput) – LossOutput object from an scvi-tools module.
    - metrics (dict[str, ElboMetric]) – Dictionary of metrics to update.
    - mode (str) – Postfix string to add to the metric name of extra metrics.
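As a rough illustration of the `mode` postfix, the helper below sketches how an extra metric's logged name could be built. The naming scheme is an assumption inferred from metric names such as elbo_validation used elsewhere on this page; the function is hypothetical, not part of the scvi-tools API:

```python
def extra_metric_name(metric: str, mode: str) -> str:
    # Hypothetical helper: append the mode (e.g. "train" or "validation")
    # as a postfix to an extra metric's name.
    return f"{metric}_{mode}"

# e.g. extra_metric_name("elbo", "validation") -> "elbo_validation"
```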