scvi.module.base.PyroBaseModuleClass
- class scvi.module.base.PyroBaseModuleClass(on_load_kwargs=None)
Bases: Module
Base module class for Pyro models.
In Pyro, model and guide should have the same signature. For convenience, the forward function of this class passes through to the forward of the model.
There are two ways this class can be equipped with a model and a guide. First, model and guide can be class attributes that are PyroModule instances; the implemented model and guide methods then return these (private) attributes. Second, the model and guide methods can be written directly (see the Pyro scANVI example: https://pyro.ai/examples/scanvi.html).
The model and guide may also be equipped with n_obs attributes, which can be set to None (e.g., self.n_obs = None). This attribute may be helpful in designating the size of observation-specific Pyro plates. The value is updated automatically by PyroTrainingPlan, provided that the plan is given the number of training examples upon initialization.
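A minimal, dependency-free sketch of the pattern described above, assuming nothing about scvi-tools internals: SketchPyroModule, PlateSizedModel and set_n_obs are hypothetical names, plain Python callables stand in for PyroModule instances, and set_n_obs only imitates the kind of update a training plan could perform. It shows model and guide exposed through accessors over private attributes, forward delegating to the model, and an n_obs attribute that starts as None and is filled in later.

```python
from typing import Any, Callable, Optional


class SketchPyroModule:
    """Hypothetical stand-in for PyroBaseModuleClass; needs neither torch nor pyro."""

    def __init__(self, model: Callable[..., Any], guide: Callable[..., Any]) -> None:
        # First pattern: model and guide are stored as private attributes ...
        self._model = model
        self._guide = guide

    @property
    def model(self) -> Callable[..., Any]:
        # ... and exposed through the public model/guide accessors.
        return self._model

    @property
    def guide(self) -> Callable[..., Any]:
        return self._guide

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Passthrough: calling forward is the same as calling the model.
        return self.model(*args, **kwargs)


class PlateSizedModel:
    """Hypothetical model callable carrying an n_obs attribute."""

    def __init__(self) -> None:
        self.n_obs: Optional[int] = None  # unknown until training is set up

    def __call__(self, x: list, y: list) -> dict:
        # A real Pyro model would typically use n_obs as the plate size and
        # len(x) as the minibatch (subsample) size; here we just report both.
        return {"plate_size": self.n_obs, "subsample_size": len(x)}


def guide(x: list, y: list) -> None:
    # Same signature as the model, as Pyro expects; a plain function has no n_obs.
    return None


def set_n_obs(module: SketchPyroModule, n_obs: int) -> None:
    """Hypothetical imitation of a training plan filling in n_obs at setup."""
    for fn in (module.model, module.guide):
        if hasattr(fn, "n_obs"):
            fn.n_obs = n_obs


if __name__ == "__main__":
    module = SketchPyroModule(PlateSizedModel(), guide)
    set_n_obs(module, n_obs=1000)
    print(module.forward([0.1, 0.2], [1, 0]))  # {'plate_size': 1000, 'subsample_size': 2}
```

In an actual Pyro model, the full dataset size is typically passed to pyro.plate as size and the minibatch size as subsample_size, so that the likelihood of each minibatch is rescaled to the whole dataset; knowing n_obs is what makes that possible.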
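Attributes table
| Name | Description |
| --- | --- |
| list_obs_plate_vars | Model annotation for minibatch training with pyro plate. |
| training | |
Methods table
| Name | Description |
| --- | --- |
| create_predictive | Creates a Predictive object. |
| forward | Passthrough to Pyro model. |
| on_load | Callback function run in scvi.model.base.BaseModelClass.load(). |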
Attributes
- PyroBaseModuleClass.list_obs_plate_vars
Model annotation for minibatch training with pyro plate.
A dictionary with:
1. “name” - the name of the observation/minibatch plate;
2. “in” - indexes of the model args to provide to the encoder network when using amortised inference;
3. “sites” - a dictionary with
   - keys - names of variables that belong to the observation plate (used to recognise and merge posterior samples for minibatch variables)
   - values - the dimensions in the non-plate axis of each variable (used to construct the output layer of the encoder network when using amortised inference)
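The dictionary format can be illustrated with a hypothetical model. The plate name "obs_plate", the sites "z" and "library_size", their sizes, and both helper functions are assumptions made for this sketch, not part of scvi-tools; the helpers only show how the "sites" entry answers the two questions the annotation is described as serving.

```python
from typing import Any, Dict

# Hypothetical annotation; plate name, site names and sizes are illustrative only.
list_obs_plate_vars: Dict[str, Any] = {
    "name": "obs_plate",  # name of the observation/minibatch plate
    "in": [0],  # model arg 0 (e.g. the expression matrix) is fed to the encoder
    "sites": {
        "z": 10,  # 10 dimensions in the non-plate axis per observation
        "library_size": 1,  # one value per observation
    },
}


def encoder_output_dim(plate_vars: Dict[str, Any]) -> int:
    """Hypothetical helper: total non-plate width over all observation-plate sites,
    i.e. one encoder output per dimension of each site."""
    return sum(plate_vars["sites"].values())


def is_minibatch_site(plate_vars: Dict[str, Any], site_name: str) -> bool:
    """Hypothetical helper: True if a posterior sample site lives in the observation
    plate, and so would need merging across minibatches."""
    return site_name in plate_vars["sites"]


if __name__ == "__main__":
    print(encoder_output_dim(list_obs_plate_vars))  # 11
    print(is_minibatch_site(list_obs_plate_vars, "z"))  # True
```

Real amortised encoders often emit several parameters per dimension (for example a location and a scale), so the width they actually need may be a multiple of this sum.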
- PyroBaseModuleClass.training: bool
Methods
- PyroBaseModuleClass.create_predictive(model=None, posterior_samples=None, guide=None, num_samples=None, return_sites=(), parallel=False)
Creates a Predictive object.
- Parameters:
  - model (Callable | None (default: None)) – Python callable containing Pyro primitives. Defaults to self.model.
  - posterior_samples (dict | None (default: None)) – Dictionary of samples from the posterior.
  - guide (Callable | None (default: None)) – Optional guide to get posterior samples of sites not present in posterior_samples. Defaults to self.guide.
  - num_samples (int | None (default: None)) – Number of samples to draw from the predictive distribution. This argument has no effect if posterior_samples is non-empty, in which case the leading dimension size of samples in posterior_samples is used.
  - return_sites (tuple[str] (default: ())) – Sites to return; by default, only sample sites not present in posterior_samples are returned.
  - parallel (bool (default: False)) – Predict in parallel by wrapping the existing model in an outermost plate messenger. Note that this requires that the model has all batch dims correctly annotated via plate.
- Return type: Predictive
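The defaulting rules for create_predictive can be sketched without Pyro installed. The parameter list mirrors that of pyro.infer.Predictive, to which these arguments are presumably forwarded; resolve_predictive_args is a hypothetical helper that only reproduces the documented behaviour for model, guide and num_samples, and raising on inconsistent leading dimensions is a choice made for this sketch.

```python
from typing import Any, Callable, Dict, Optional, Sequence


def resolve_predictive_args(
    default_model: Callable[..., Any],
    default_guide: Callable[..., Any],
    model: Optional[Callable[..., Any]] = None,
    posterior_samples: Optional[Dict[str, Sequence[Any]]] = None,
    guide: Optional[Callable[..., Any]] = None,
    num_samples: Optional[int] = None,
) -> Dict[str, Any]:
    """Hypothetical helper reproducing the documented defaults of create_predictive."""
    # model and guide fall back to the module's own model and guide.
    model = model if model is not None else default_model
    guide = guide if guide is not None else default_guide
    if posterior_samples:
        # A non-empty posterior_samples overrides num_samples: the number of draws
        # is the leading dimension (here, the length) of each site's samples.
        sizes = {len(draws) for draws in posterior_samples.values()}
        if len(sizes) != 1:
            raise ValueError("all sites must share the same leading dimension")
        num_samples = sizes.pop()
    return {"model": model, "guide": guide, "num_samples": num_samples}


if __name__ == "__main__":
    def my_model(x: float) -> float:
        return x

    def my_guide(x: float) -> None:
        return None

    args = resolve_predictive_args(
        my_model, my_guide, posterior_samples={"z": [0.1, 0.2, 0.3]}, num_samples=100
    )
    print(args["num_samples"])  # 3: the leading dimension wins over num_samples=100
```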
- PyroBaseModuleClass.on_load(model)
Callback function run in scvi.model.base.BaseModelClass.load().
For some Pyro modules with AutoGuides, runs one training step prior to loading the state dict.
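Pyro AutoGuides create their parameters lazily, the first time they are called, so a freshly constructed module has nothing for a saved state dict to load into; that is presumably why one training step is run first. The toy sketch below imitates this with plain Python: LazyGuide, load_state_dict and on_load_sketch are hypothetical stand-ins, and a single call to the guide stands in for the training step.

```python
from typing import Dict


class LazyGuide:
    """Hypothetical stand-in for an AutoGuide whose parameters only exist
    after it has been called once."""

    def __init__(self) -> None:
        self.params: Dict[str, float] = {}

    def __call__(self) -> None:
        # First call creates the parameters with placeholder initial values.
        if not self.params:
            self.params = {"loc": 0.0, "scale": 1.0}


def load_state_dict(guide: LazyGuide, state: Dict[str, float]) -> None:
    """Hypothetical strict loader: every saved key must match an existing parameter."""
    unexpected = set(state) - set(guide.params)
    if unexpected:
        raise KeyError(f"unexpected keys in state dict: {sorted(unexpected)}")
    guide.params.update(state)


def on_load_sketch(guide: LazyGuide, saved: Dict[str, float]) -> None:
    """Hypothetical analogue of on_load: one 'training step' (here, a single call
    to the guide) materialises the parameters before the state dict is loaded."""
    guide()
    load_state_dict(guide, saved)


if __name__ == "__main__":
    saved = {"loc": 2.5, "scale": 0.3}
    fresh = LazyGuide()
    try:
        load_state_dict(fresh, saved)  # fails: no parameters exist yet
    except KeyError as err:
        print("before warm-up:", err)
    on_load_sketch(fresh, saved)
    print(fresh.params)  # {'loc': 2.5, 'scale': 0.3}
```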