scvi.train.LowLevelPyroTrainingPlan#

class scvi.train.LowLevelPyroTrainingPlan(pyro_module, loss_fn=None, optim=None, optim_kwargs=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, scale_elbo=1.0)[source]#

Bases: TunableMixin, LightningModule

Lightning module task to train Pyro scvi-tools modules.

Parameters:
  • pyro_module (PyroBaseModuleClass) – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.

  • loss_fn (Optional[ELBO]) – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.

  • optim (Optional[Adam]) – A PyTorch optimizer class, e.g., Adam. If None, defaults to torch.optim.Adam.

  • optim_kwargs (Optional[dict]) – Keyword arguments for the optimizer. If None, defaults to dict(lr=1e-3).

  • n_steps_kl_warmup (Optional[int]) – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup (Optional[int]) – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.

  • scale_elbo (float) – Scale the ELBO by this factor. Potentially useful for avoiding numerical inaccuracy when working with a very large ELBO.
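The two warmup parameters above interact as described: epoch-based warmup takes precedence when both are set, and step-based warmup only activates when n_epochs_kl_warmup is None. A minimal sketch of that documented schedule (the function name and the linear ramp are illustrative, not scvi-tools' actual implementation):

```python
def kl_weight(epoch, step, n_epochs_kl_warmup=400, n_steps_kl_warmup=None):
    """Sketch of the documented KL warmup schedule.

    Epoch-based warmup overrides step-based warmup when both are not None;
    with neither set, the KL term is unscaled (weight 1.0).
    """
    if n_epochs_kl_warmup is not None:
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        return min(1.0, step / n_steps_kl_warmup)
    return 1.0
```

With the default n_epochs_kl_warmup=400, the weight reaches 0.5 at epoch 200 and saturates at 1.0 from epoch 400 onward.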

Attributes table#

kl_weight

Scaling factor on KL divergence during training.

n_obs_training

Number of training examples.

Methods table#

configure_optimizers()

Configure optimizers for the model.

forward(*args, **kwargs)

Passthrough to the model's forward method.

training_epoch_end(outputs)

Training epoch end for Pyro training.

training_step(batch, batch_idx)

Training step for Pyro training.

Attributes#

kl_weight

LowLevelPyroTrainingPlan.kl_weight[source]#

Scaling factor on KL divergence during training.

n_obs_training

LowLevelPyroTrainingPlan.n_obs_training[source]#

Number of training examples.

If not None, updates the n_obs attr of the Pyro module’s model and guide, if they exist.
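The conditional update described above can be sketched with plain Python objects standing in for the Pyro module's model and guide (the class and function names here are illustrative, not part of the scvi-tools API):

```python
class _Model:
    n_obs = None  # stand-in for a Pyro model that tracks dataset size


class _Guide:
    pass  # no n_obs attribute, so the setter leaves it untouched


def set_n_obs_training(model, guide, n_obs):
    """Sketch of the documented setter: update n_obs only where it exists."""
    if n_obs is not None:
        for obj in (model, guide):
            if hasattr(obj, "n_obs"):
                obj.n_obs = n_obs


model, guide = _Model(), _Guide()
set_n_obs_training(model, guide, 10_000)
```

After the call, model.n_obs is 10000, while guide, which never defined the attribute, is left unchanged.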

training

LowLevelPyroTrainingPlan.training: bool#

Methods#

configure_optimizers

LowLevelPyroTrainingPlan.configure_optimizers()[source]#

Configure optimizers for the model.

forward

LowLevelPyroTrainingPlan.forward(*args, **kwargs)[source]#

Passthrough to the model’s forward method.

training_epoch_end

LowLevelPyroTrainingPlan.training_epoch_end(outputs)[source]#

Training epoch end for Pyro training.

training_step

LowLevelPyroTrainingPlan.training_step(batch, batch_idx)[source]#

Training step for Pyro training.