scvi.train.PyroTrainingPlan

class scvi.train.PyroTrainingPlan(pyro_module, loss_fn=None, optim=None)

Bases: pytorch_lightning.core.lightning.LightningModule

Lightning module task for training Pyro-based scvi-tools modules.

Parameters
pyro_module : PyroBaseModuleClass

An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.

loss_fn : Optional[ELBO] (default: None)

A Pyro loss. Should be an instance of a subclass of ELBO. If None, defaults to Trace_ELBO.

optim : Optional[PyroOptim] (default: None)

A Pyro optimizer, e.g., Adam. If None, defaults to the Adam optimizer with a learning rate of 1e-3.

Methods

backward(*args, **kwargs)

Override backward with your own implementation if you need to.

configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization.

forward(*args, **kwargs)

Passthrough to model.forward().

optimizer_step(*args, **kwargs)

Override this method to adjust the default way the Trainer calls each optimizer.

training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps.

training_step(batch, batch_idx)

Here you compute and return the training loss and some additional metrics, e.g., for the progress bar or logger.