scvi.train.PyroTrainingPlan#
- class scvi.train.PyroTrainingPlan(pyro_module, loss_fn=None, optim=None, optim_kwargs=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, scale_elbo=1.0)[source]#
Bases: LowLevelPyroTrainingPlan
Lightning module task to train Pyro scvi-tools modules.
- Parameters:
  - **pyro_module** (`PyroBaseModuleClass`) – An instance of `PyroBaseModuleClass`. This object should have callable model and guide attributes or methods.
  - **loss_fn** (`ELBO` | `None`, default: `None`) – A Pyro loss. Should be a subclass of `ELBO`. If `None`, defaults to `Trace_ELBO`.
  - **optim** (`PyroOptim` | `None`, default: `None`) – A Pyro optimizer instance, e.g., `Adam`. If `None`, defaults to the `pyro.optim.Adam` optimizer with a learning rate of 1e-3.
  - **optim_kwargs** (`dict` | `None`, default: `None`) – Keyword arguments for the default optimizer `pyro.optim.Adam`.
  - **n_steps_kl_warmup** (`int` | `None`, default: `None`) – Number of training steps (minibatches) over which to scale the weight on KL divergences from 0 to 1. Only activated when `n_epochs_kl_warmup` is set to `None`.
  - **n_epochs_kl_warmup** (`int` | `None`, default: `400`) – Number of epochs over which to scale the weight on KL divergences from 0 to 1. Overrides `n_steps_kl_warmup` when both are not `None`.
  - **scale_elbo** (`float`, default: `1.0`) – Scale the ELBO using `scale`. Potentially useful for avoiding numerical inaccuracy when working with very large ELBO values.
Attributes table#

| `training` |  |

Methods table#

| `backward`(*args, **kwargs) | Called to perform backward on the loss returned in `training_step()`. |
| `configure_optimizers`() | Shim optimizer for PyTorch Lightning. |
| `optimizer_step`(*args, **kwargs) | Override this method to adjust the default way the `Trainer` calls the optimizer. |
| `training_step` | Training step for Pyro training. |
Attributes#
- PyroTrainingPlan.training: bool#
Methods#
- PyroTrainingPlan.backward(*args, **kwargs)[source]#

Called to perform backward on the loss returned in `training_step()`. Override this hook with your own implementation if you need to.

- Parameters:
  - **loss** – The loss tensor returned by `training_step()`. If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps).

Example:

```python
def backward(self, loss):
    loss.backward()
```
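Per the note on gradient accumulation above, the loss arriving in `backward()` is already normalized. A tiny sketch of that relationship (the helper `normalized_loss` is hypothetical, for illustration only):

```python
def normalized_loss(raw_loss, accumulate_grad_batches):
    # With gradient accumulation, each minibatch loss is scaled by
    # 1 / accumulation steps before it reaches backward(), so the
    # accumulated gradient matches an average over the batches.
    return raw_loss / accumulate_grad_batches

print(normalized_loss(2.0, 4))  # 0.5
```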
- PyroTrainingPlan.configure_optimizers()[source]#
Shim optimizer for PyTorch Lightning.
PyTorch Lightning wants to take steps on an optimizer returned by this function in order to increment the global step count. See the PyTorch Lightning documentation on the manual optimization loop.
Here we provide a shim optimizer that we can take steps on at minimal computational cost in order to keep Lightning happy :).
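The shim idea can be sketched in plain Python (an illustrative stand-in, not the actual scvi-tools implementation): an optimizer-like object whose `step()` is nearly free, yet still lets a Lightning-style loop call `step()`/`zero_grad()` and advance its global step count while Pyro's own optimizer does the real work inside the closure.

```python
class ShimOptimizer:
    """Illustrative stand-in for the shim: stepping it costs almost nothing,
    but the surrounding loop can still count global steps."""

    def __init__(self):
        self.global_step = 0

    def step(self, closure=None):
        # Execute the closure if given (in the real plan, this is where the
        # Pyro update runs), then advance the step counter the loop relies on.
        loss = closure() if closure is not None else None
        self.global_step += 1
        return loss

    def zero_grad(self):
        pass  # nothing to clear; Pyro's optimizer manages its own state

opt = ShimOptimizer()
opt.step(closure=lambda: 1.23)
print(opt.global_step)  # 1
```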
- PyroTrainingPlan.optimizer_step(*args, **kwargs)[source]#

Override this method to adjust the default way the `Trainer` calls the optimizer.

By default, Lightning calls `step()` and `zero_grad()` as shown in the example. This method (and `zero_grad()`) won't be called during the accumulation phase when `Trainer(accumulate_grad_batches != 1)`. Overriding this hook has no benefit with manual optimization.

- Parameters:
  - **epoch** – Current epoch
  - **batch_idx** – Index of current batch
  - **optimizer** – A PyTorch optimizer
  - **optimizer_closure** – The optimizer closure. This closure must be executed as it includes the calls to `training_step()`, `optimizer.zero_grad()`, and `backward()`.

Examples:

```python
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # Add your custom logic to run directly before `optimizer.step()`
    optimizer.step(closure=optimizer_closure)
    # Add your custom logic to run directly after `optimizer.step()`
```
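One concrete use of this hook is learning-rate warmup, which the Lightning docs use as their canonical example. Here is a self-contained sketch of that pattern (`FakeOptimizer`, `warmup_steps`, and `base_lr` are illustrative stand-ins so the logic runs without a real `Trainer`):

```python
class FakeOptimizer:
    """Stand-in exposing only the surface used below: param_groups and step()."""
    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]

    def step(self, closure=None):
        if closure is not None:
            closure()

def optimizer_step(global_step, optimizer, optimizer_closure,
                   warmup_steps=500, base_lr=1e-3):
    # Linearly ramp the learning rate over the first `warmup_steps` steps.
    scale = min(1.0, (global_step + 1) / warmup_steps)
    for pg in optimizer.param_groups:
        pg["lr"] = scale * base_lr
    # The closure must be executed: it wraps training_step(),
    # optimizer.zero_grad(), and backward().
    optimizer.step(closure=optimizer_closure)

opt = FakeOptimizer(lr=0.0)
optimizer_step(global_step=49, optimizer=opt, optimizer_closure=lambda: None)
print(opt.param_groups[0]["lr"])  # ~1e-4: 10% of base_lr after 50 of 500 steps
```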