scvi.train.LowLevelPyroTrainingPlan
- class scvi.train.LowLevelPyroTrainingPlan(pyro_module, loss_fn=None, optim=None, optim_kwargs=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, scale_elbo=1.0, max_kl_weight=1.0, min_kl_weight=1e-06)
Bases: LightningModule
Lightning module task to train Pyro scvi-tools modules.
- Parameters:
  - pyro_module (PyroBaseModuleClass) – An instance of PyroBaseModuleClass. This object should have callable model and guide attributes or methods.
  - loss_fn (ELBO | None (default: None)) – A Pyro loss. Should be a subclass of ELBO. If None, defaults to Trace_ELBO.
  - optim (Adam | None (default: None)) – A PyTorch optimizer class, e.g., Adam. If None, defaults to torch.optim.Adam.
  - optim_kwargs (dict | None (default: None)) – Keyword arguments for the optimizer. If None, defaults to dict(lr=1e-3).
  - n_steps_kl_warmup (int | None (default: None)) – Number of training steps (minibatches) over which to scale the weight on KL divergences from 0 to 1. Only used when n_epochs_kl_warmup is None.
  - n_epochs_kl_warmup (int | None (default: 400)) – Number of epochs over which to scale the weight on KL divergences from 0 to 1. Takes precedence over n_steps_kl_warmup when both are set.
  - scale_elbo (float (default: 1.0)) – Scale the ELBO using Pyro's scale handler (pyro.poutine.scale). Potentially useful for avoiding numerical inaccuracy when working with a very large ELBO.
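The two warmup parameters interact as follows: an epoch-based warmup takes precedence, a step-based warmup is used only when n_epochs_kl_warmup is None, and with neither set the full weight applies from the start. The sketch below assumes a linear ramp from min_kl_weight to max_kl_weight (the parameter text only says the weight is scaled "from 0 to 1"); kl_weight_schedule is a hypothetical helper for illustration, not a documented scvi-tools function.

```python
def kl_weight_schedule(epoch, step, n_epochs_kl_warmup=400, n_steps_kl_warmup=None,
                       min_kl_weight=1e-6, max_kl_weight=1.0):
    """Hypothetical linear KL warmup; epoch-based warmup takes precedence over step-based."""
    slope = max_kl_weight - min_kl_weight
    if n_epochs_kl_warmup is not None:
        # Epoch-based warmup: ramp linearly until n_epochs_kl_warmup, then hold at the maximum.
        if epoch < n_epochs_kl_warmup:
            return min_kl_weight + slope * epoch / n_epochs_kl_warmup
        return max_kl_weight
    if n_steps_kl_warmup is not None:
        # Step-based warmup: only consulted when n_epochs_kl_warmup is None.
        if step < n_steps_kl_warmup:
            return min_kl_weight + slope * step / n_steps_kl_warmup
        return max_kl_weight
    # No warmup configured: use the full KL weight from the first step.
    return max_kl_weight


# Usage: epoch-based warmup with the documented defaults (400 epochs).
for epoch in (0, 100, 200, 400):
    print(epoch, kl_weight_schedule(epoch, step=0))
```

A linear ramp is the simplest schedule consistent with the description; other annealing shapes (e.g. sigmoid or cyclical) would need different logic.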
Attributes table
| Attribute | Description |
|---|---|
| kl_weight | Scaling factor on KL divergence during training. |
| n_obs_training | Number of training examples. |
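Methods table
| Method | Description |
|---|---|
| configure_optimizers | Configure optimizers for the model. |
| forward | Passthrough to the model's forward method. |
| on_train_epoch_end | Training epoch end for Pyro training. |
| training_step | Training step for Pyro training. |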
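The per-minibatch training step and the epoch-end hook typically split the work so that each step produces a loss and the epoch end summarizes those losses. The sketch below illustrates that pattern in plain Python under stated assumptions: ToyPyroPlan, its epoch_loss key, and averaging as the summary are hypothetical choices for illustration (method names mirror PyTorch Lightning hooks), and scale_elbo is modeled as a plain multiplier on the loss, a simplification of scaling inside Pyro.

```python
from statistics import fmean


class ToyPyroPlan:
    """Hypothetical stand-in for a training plan's per-step and per-epoch hooks.

    loss_fn is any callable(batch) -> float standing in for the Pyro ELBO loss.
    """

    def __init__(self, loss_fn, scale_elbo=1.0):
        self.loss_fn = loss_fn
        self.scale_elbo = scale_elbo  # multiplies every per-batch loss (simplified ELBO scaling)
        self.step_losses = []  # losses collected over the current epoch
        self.logged = {}  # stands in for Lightning's logging

    def training_step(self, batch, batch_idx):
        # One minibatch: compute the scaled loss and keep it for the epoch summary.
        # batch_idx is unused here; it only mirrors the Lightning hook signature.
        loss = self.scale_elbo * self.loss_fn(batch)
        self.step_losses.append(loss)
        return {"loss": loss}

    def on_train_epoch_end(self):
        # Average the per-step losses into one epoch-level value, then reset for the next epoch.
        self.logged["epoch_loss"] = fmean(self.step_losses)
        self.step_losses.clear()


# Usage: two minibatches whose toy "loss" is the sum of their entries.
plan = ToyPyroPlan(loss_fn=sum, scale_elbo=0.5)
for batch_idx, batch in enumerate([[1.0, 3.0], [4.0, 4.0]]):
    plan.training_step(batch, batch_idx)
plan.on_train_epoch_end()
print(plan.logged["epoch_loss"])  # 0.5 * mean(4.0, 8.0) = 3.0
```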
Attributes
- LowLevelPyroTrainingPlan.n_obs_training
Number of training examples.
If not None, setting this value also updates the n_obs attribute of the Pyro module's model and guide, if they exist.
- LowLevelPyroTrainingPlan.training: bool
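The n_obs_training behavior (propagate a non-None value to the model and guide) maps naturally onto a Python property setter. HypotheticalTrainingPlan below is an illustrative assumption, not the scvi-tools class, and it reads "if they exist" as "if the model or guide exposes an n_obs attribute".

```python
from types import SimpleNamespace


class HypotheticalTrainingPlan:
    """Illustrative sketch of a setter that forwards n_obs to the module's model and guide."""

    def __init__(self, module):
        self.module = module
        self._n_obs_training = None

    @property
    def n_obs_training(self):
        return self._n_obs_training

    @n_obs_training.setter
    def n_obs_training(self, n_obs):
        if n_obs is not None:
            # Only touch model/guide objects that already define an n_obs attribute.
            for name in ("model", "guide"):
                part = getattr(self.module, name, None)
                if part is not None and hasattr(part, "n_obs"):
                    part.n_obs = n_obs
        self._n_obs_training = n_obs


# Usage: a model that tracks n_obs and a guide that does not.
model = SimpleNamespace(n_obs=None)
guide = SimpleNamespace()
plan = HypotheticalTrainingPlan(SimpleNamespace(model=model, guide=guide))
plan.n_obs_training = 1000
```

Checking for an existing n_obs attribute, rather than creating one, keeps the setter from adding state to model or guide objects that never use it.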