class scvi.train.PyroTrainingPlan(pyro_module, loss_fn=None, optim=None, optim_kwargs=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400)

Bases: pytorch_lightning.core.lightning.LightningModule

Lightning module task to train Pyro scvi-tools modules.

pyro_module : PyroBaseModuleClass

An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.

loss_fn : ELBO | None (default: None)

A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.

optim : PyroOptim | None (default: None)

A Pyro optimizer instance, e.g., Adam. If None, defaults to pyro.optim.Adam optimizer with a learning rate of 1e-3.

optim_kwargs : dict | None (default: None)

Keyword arguments for the default optimizer, pyro.optim.Adam.

n_steps_kl_warmup : int | None (default: None)

Number of training steps (minibatches) over which the weight on the KL divergences is scaled from 0 to 1. Only used when n_epochs_kl_warmup is set to None.

n_epochs_kl_warmup : int | None (default: 400)

Number of epochs over which the weight on the KL divergences is scaled from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
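The warm-up schedule described by these two parameters can be sketched in plain Python. This is an illustrative helper, not part of the scvi-tools API; the function name `kl_warmup_weight` and its exact signature are assumptions, though the linear 0-to-1 scaling and the epoch-over-step precedence follow the parameter descriptions above.

```python
def kl_warmup_weight(current_epoch, current_step,
                     n_epochs_kl_warmup=400, n_steps_kl_warmup=None):
    """Linearly scale the KL weight from 0 to 1 during training (sketch)."""
    if n_epochs_kl_warmup is not None:
        # Epoch-based warm-up takes precedence when both are set.
        return min(1.0, current_epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        # Step-based warm-up is only active when n_epochs_kl_warmup is None.
        return min(1.0, current_step / n_steps_kl_warmup)
    return 1.0  # no warm-up configured: full KL weight from the start
```

With the defaults, the weight reaches 1.0 at epoch 400 and stays there for the rest of training.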

kl_weight

Scaling factor on the KL divergence during training.


n_obs

Number of training examples.


backward(*args, **kwargs)

Override backward with your own implementation if you need to.


configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization.

forward(*args, **kwargs)

Passthrough to model.forward().

optimizer_step(*args, **kwargs)

Override this method to adjust the default way the Trainer calls each optimizer.


training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps.

training_step(batch, batch_idx)

Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.
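Conceptually, a Pyro training step wraps the differentiable loss from loss_fn (by default Trace_ELBO) around the module's model and guide and reports it to Lightning as a dict with a "loss" key. The sketch below is illustrative only; the name `training_step_sketch` and the exact call shape are assumptions, not the scvi-tools implementation.

```python
def training_step_sketch(differentiable_loss, model, guide, batch):
    """Compute a Pyro ELBO-style loss for one minibatch (illustrative sketch).

    `differentiable_loss` stands in for something like
    pyro.infer.Trace_ELBO().differentiable_loss.
    """
    loss = differentiable_loss(model, guide, batch)
    # Lightning expects the training loss under the "loss" key.
    return {"loss": loss}
```

For example, passing a dummy loss callable that sums the batch values returns `{"loss": <sum>}`, which Lightning would then backpropagate through.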