scvi.train.LowLevelPyroTrainingPlan
- class scvi.train.LowLevelPyroTrainingPlan(pyro_module, loss_fn=None, optim=None, optim_kwargs=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, scale_elbo=1.0)
Bases: TunableMixin, LightningModule
Lightning module task for training Pyro-based scvi-tools modules.
- Parameters:
  - pyro_module (PyroBaseModuleClass) – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.
  - loss_fn (Optional[ELBO]) – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.
  - optim (Optional[Adam]) – A PyTorch optimizer class, e.g., Adam. If None, defaults to torch.optim.Adam.
  - optim_kwargs (Optional[dict]) – Keyword arguments for the optimizer. If None, defaults to dict(lr=1e-3).
  - n_steps_kl_warmup (Optional[int]) – Number of training steps (minibatches) over which the weight on KL divergences is scaled from 0 to 1. Only used when n_epochs_kl_warmup is set to None.
  - n_epochs_kl_warmup (Optional[int]) – Number of epochs over which the weight on KL divergences is scaled from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
  - scale_elbo (float) – Factor by which to scale the ELBO, applied with Pyro's scale handler (pyro.poutine.scale). Potentially useful for avoiding numerical inaccuracy when working with a very large ELBO.
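The two warmup parameters interact as described above: the epoch-based schedule wins whenever n_epochs_kl_warmup is set, and the step-based one is consulted only when it is None. A minimal sketch of such a schedule follows, assuming a linear ramp from 0 to 1. The helper compute_kl_weight and its min_kl_weight/max_kl_weight arguments are hypothetical and not part of the scvi-tools API.

```python
from typing import Optional


def compute_kl_weight(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: Optional[int],
    n_steps_kl_warmup: Optional[int],
    min_kl_weight: float = 0.0,
    max_kl_weight: float = 1.0,
) -> float:
    """Hypothetical linear KL-warmup schedule (illustrative sketch only)."""
    slope = max_kl_weight - min_kl_weight
    # Epoch-based warmup takes precedence whenever it is set.
    if n_epochs_kl_warmup is not None:
        if epoch < n_epochs_kl_warmup:
            return min_kl_weight + slope * epoch / n_epochs_kl_warmup
        return max_kl_weight
    # Step-based warmup is only consulted when the epoch-based one is None.
    if n_steps_kl_warmup is not None:
        if step < n_steps_kl_warmup:
            return min_kl_weight + slope * step / n_steps_kl_warmup
        return max_kl_weight
    # No warmup configured: use the full KL weight from the start.
    return max_kl_weight


# Default epoch-based warmup of 400 epochs, a quarter of the way through.
print(compute_kl_weight(epoch=100, step=0, n_epochs_kl_warmup=400, n_steps_kl_warmup=None))  # 0.25
# Step-based warmup applies here only because n_epochs_kl_warmup is None.
print(compute_kl_weight(epoch=0, step=500, n_epochs_kl_warmup=None, n_steps_kl_warmup=1000))  # 0.5
```

Note that with the default n_epochs_kl_warmup=400, any value passed for n_steps_kl_warmup is ignored; set n_epochs_kl_warmup=None to warm up per minibatch instead.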
Attributes table
| kl_weight | Scaling factor on KL divergence during training. |
| n_obs_training | Number of training examples. |
| training | |
Methods table
| configure_optimizers | Configure optimizers for the model. |
| forward | Passthrough to the model's forward method. |
| training_epoch_end | Hook run at the end of each training epoch for Pyro training. |
| training_step | Training step for Pyro training. |
Attributes
kl_weight
Scaling factor on KL divergence during training.
n_obs_training
- LowLevelPyroTrainingPlan.n_obs_training
Number of training examples.
If set to a value other than None, this also updates the n_obs attribute of the Pyro module's model and guide, where that attribute exists.
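The update rule for n_obs_training can be pictured as a property setter that pushes the value down to the model and guide. A minimal sketch follows, using hypothetical stand-in classes (_Stub, _Module, PlanSketch) instead of real Pyro or Lightning objects; checking for an existing n_obs attribute with hasattr is an assumption about how "if they exist" is implemented.

```python
from typing import Optional


class _Stub:
    """Hypothetical stand-in for a Pyro model or guide that exposes n_obs."""

    def __init__(self) -> None:
        self.n_obs: Optional[int] = None


class _Module:
    """Hypothetical stand-in for a PyroBaseModuleClass with model and guide."""

    def __init__(self, model, guide) -> None:
        self.model = model
        self.guide = guide


class PlanSketch:
    """Sketch of the n_obs_training property; not the scvi-tools class."""

    def __init__(self, module) -> None:
        self.module = module
        self._n_obs_training: Optional[int] = None

    @property
    def n_obs_training(self) -> Optional[int]:
        return self._n_obs_training

    @n_obs_training.setter
    def n_obs_training(self, n_obs: Optional[int]) -> None:
        # Only a concrete value is propagated; None leaves model/guide untouched.
        if n_obs is not None:
            for name in ("model", "guide"):
                part = getattr(self.module, name, None)
                # Update n_obs only on objects that already define it.
                if part is not None and hasattr(part, "n_obs"):
                    part.n_obs = n_obs
        self._n_obs_training = n_obs


module = _Module(model=_Stub(), guide=object())  # this guide has no n_obs
plan = PlanSketch(module)
plan.n_obs_training = 1000
print(module.model.n_obs)  # 1000
```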
training
Methods
configure_optimizers
forward
- LowLevelPyroTrainingPlan.forward(*args, **kwargs)
Passthrough to the model’s forward method.
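forward adds no logic of its own; it simply hands its positional and keyword arguments to the wrapped object. A minimal sketch of this passthrough pattern follows, with hypothetical stand-ins (_ToyModule, PassthroughPlan) in place of the real Pyro module and LightningModule. Exactly which object the real method delegates to is an assumption here.

```python
class _ToyModule:
    """Hypothetical module whose forward just sums its inputs."""

    def forward(self, *args, **kwargs):
        return sum(args) + sum(kwargs.values())

    def __call__(self, *args, **kwargs):
        # Mimic torch.nn.Module, where calling the object dispatches to forward.
        return self.forward(*args, **kwargs)


class PassthroughPlan:
    """Sketch of a training plan whose forward delegates to the wrapped module."""

    def __init__(self, module) -> None:
        self.module = module

    def forward(self, *args, **kwargs):
        # No extra processing: arguments go straight to the module.
        return self.module(*args, **kwargs)


plan = PassthroughPlan(_ToyModule())
result = plan.forward(1, 2, bias=3)
print(result)  # 6
```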
training_epoch_end
training_step