scvi.train.LowLevelPyroTrainingPlan
- class scvi.train.LowLevelPyroTrainingPlan(pyro_module, loss_fn=None, optim=None, optim_kwargs=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, scale_elbo=1.0)
Bases: LightningModule
Lightning module task to train Pyro scvi-tools modules.
- Parameters:
  - pyro_module (PyroBaseModuleClass) – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.
  - loss_fn (ELBO | None (default: None)) – A Pyro loss. Should be an instance of a subclass of ELBO. If None, defaults to Trace_ELBO.
  - optim (Adam | None (default: None)) – A PyTorch optimizer class, e.g., Adam. If None, defaults to torch.optim.Adam.
  - optim_kwargs (dict | None (default: None)) – Keyword arguments for the optimizer. If None, defaults to dict(lr=1e-3).
  - n_steps_kl_warmup (int | None (default: None)) – Number of training steps (minibatches) over which the weight on KL divergences is scaled from 0 to 1. Only used when n_epochs_kl_warmup is None.
  - n_epochs_kl_warmup (int | None (default: 400)) – Number of epochs over which the weight on KL divergences is scaled from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
  - scale_elbo (float (default: 1.0)) – Factor by which the ELBO is scaled, using scale. Potentially useful for avoiding numerical inaccuracy when working with a very large ELBO.
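The KL warmup rules above (epoch-based warmup takes precedence; step-based warmup applies only when n_epochs_kl_warmup is None) can be sketched as a small pure-Python function. This is a minimal sketch, not the scvi-tools implementation: the helper name `kl_weight_sketch` is hypothetical, and the linear ramp from 0 to 1 is an assumption about how the weight is scaled.

```python
from typing import Optional


def kl_weight_sketch(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: Optional[int] = 400,
    n_steps_kl_warmup: Optional[int] = None,
) -> float:
    """Hypothetical helper returning the KL weight for a given epoch/step.

    Assumes a linear ramp from 0 to 1 that is then held at 1.
    """
    if n_epochs_kl_warmup:
        # Epoch-based warmup takes precedence whenever it is set
        # (None or 0 disables it in this sketch).
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup:
        # Step (minibatch) warmup applies only when the epoch warmup is off.
        return min(1.0, step / n_steps_kl_warmup)
    # No warmup configured: KL terms get full weight from the start.
    return 1.0


# Defaults: a 400-epoch warmup, so epoch 200 gives half weight.
print(kl_weight_sketch(epoch=200, step=0))  # 0.5

# Step-based warmup, used once the epoch-based one is disabled.
print(
    kl_weight_sketch(
        epoch=0, step=250, n_epochs_kl_warmup=None, n_steps_kl_warmup=1000
    )
)  # 0.25
```

Keeping the schedule a pure function of (epoch, step) makes the precedence rule easy to check in isolation from any training loop.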
Attributes table
- kl_weight – Scaling factor on KL divergence during training.
- n_obs_training – Number of training examples.
Methods table
- Configure optimizers for the model.
- Passthrough to the model's forward method.
- Training epoch end for Pyro training.
- Training step for Pyro training.
Attributes
- LowLevelPyroTrainingPlan.n_obs_training
Number of training examples.
If not None, updates the n_obs attribute of the Pyro module's model and guide, if they exist.
- LowLevelPyroTrainingPlan.training: bool