scvi.module.base.PyroBaseModuleClass#

class scvi.module.base.PyroBaseModuleClass(on_load_kwargs=None)[source]#

Bases: Module

Base module class for Pyro models.

In Pyro, model and guide should have the same signature. Out of convenience, the forward function of this class passes through to the forward of the model.

There are two ways to equip this class with a model and a guide. First, model and guide can be class attributes that are PyroModule instances; the implemented model and guide methods can then return the (private) attributes. Second, the model and guide methods can be written directly (see the Pyro scANVI example: https://pyro.ai/examples/scanvi.html).

The model and guide may also be equipped with n_obs attributes, which can be set to None (e.g., self.n_obs = None). This attribute may be helpful in designating the size of observation-specific Pyro plates. The value will be updated automatically by PyroTrainingPlan, provided that it is given the number of training examples upon initialization.
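The first approach and the passthrough behaviour of forward can be sketched as follows. This is a minimal illustration, not scvi-tools code: ToyModel, ToyGuide and ToyModule are hypothetical names, and plain Python callables stand in for PyroModule instances so that the example runs without scvi or pyro installed. A real module would subclass PyroBaseModuleClass instead.

```python
class ToyModel:
    """Hypothetical stand-in for a PyroModule model."""

    def __init__(self):
        # Placeholder for the number of observations; per the text above,
        # a training plan may fill this in later.
        self.n_obs = None

    def __call__(self, x):
        return ("model", x)


class ToyGuide:
    """Hypothetical stand-in for a PyroModule guide with the same signature."""

    def __init__(self):
        self.n_obs = None

    def __call__(self, x):
        return ("guide", x)


class ToyModule:
    """Mimics the shape of a PyroBaseModuleClass subclass (assumption)."""

    def __init__(self):
        # Model and guide are stored as private attributes.
        self._model = ToyModel()
        self._guide = ToyGuide()

    @property
    def model(self):
        return self._model

    @property
    def guide(self):
        return self._guide

    def forward(self, *args, **kwargs):
        # Passthrough: forward simply calls the model with the same arguments.
        return self.model(*args, **kwargs)


module = ToyModule()
result = module.forward([1.0, 2.0])
```

Because model and guide share one signature, the same arguments passed to forward could equally be passed to module.guide.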

Parameters:
on_load_kwargs : dict | None (default: None)

Dictionary containing keyword args to use in self.on_load.

Attributes table#

guide

list_obs_plate_vars

Model annotation for minibatch training with pyro plate.

model

Methods table#

create_predictive([model, ...])

Creates a Predictive object.

forward(*args, **kwargs)

Passthrough to Pyro model.

on_load(model)

Callback function run in scvi.model.base.BaseModelClass.load() prior to loading module state dict.

Attributes#

T_destination#

PyroBaseModuleClass.T_destination#

alias of TypeVar(‘T_destination’, bound=Dict[str, Any])

dump_patches#

PyroBaseModuleClass.dump_patches: bool = False#

guide#

PyroBaseModuleClass.guide[source]#

list_obs_plate_vars#

PyroBaseModuleClass.list_obs_plate_vars[source]#

Model annotation for minibatch training with pyro plate.

A dictionary with:

1. “name” - the name of the observation/minibatch plate;
2. “in” - indexes of model args to provide to the encoder network when using amortised inference;
3. “sites” - a dictionary with

   keys - names of variables that belong to the observation plate (used to recognise and merge posterior samples for minibatch variables)

   values - the dimensions in the non-plate axis of each variable (used to construct the output layer of the encoder network when using amortised inference)
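As an illustration of the dictionary described above, a value returned by list_obs_plate_vars might look like the sketch below. The plate name, argument index, site names and dimensions are all hypothetical placeholders chosen for this example, not values taken from any scvi-tools model.

```python
# Hypothetical annotation for a model whose first positional argument
# (index 0) is the data fed to the encoder network.
list_obs_plate_vars = {
    "name": "obs_plate",  # name of the observation/minibatch plate
    "in": [0],            # indexes of model args passed to the encoder
    "sites": {
        "z": 10,          # latent variable with 10 non-plate dimensions
        "library": 1,     # scalar per-observation variable
    },
}

# Names of the variables that live inside the observation plate.
obs_site_names = list(list_obs_plate_vars["sites"])

# Total output width an amortised encoder would need for these sites.
encoder_output_dim = sum(list_obs_plate_vars["sites"].values())
```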

model#

PyroBaseModuleClass.model[source]#

training#

PyroBaseModuleClass.training: bool#

Methods#

create_predictive#

PyroBaseModuleClass.create_predictive(model=None, posterior_samples=None, guide=None, num_samples=None, return_sites=(), parallel=False)[source]#

Creates a Predictive object.

Parameters:
model : Callable | None (default: None)

Python callable containing Pyro primitives. Defaults to self.model.

posterior_samples : dict | None (default: None)

Dictionary of samples from the posterior.

guide : Callable | None (default: None)

Optional guide to get posterior samples of sites not present in posterior_samples. Defaults to self.guide.

num_samples : int | None (default: None)

Number of samples to draw from the predictive distribution. This argument has no effect if posterior_samples is non-empty, in which case, the leading dimension size of samples in posterior_samples is used.

return_sites : Tuple[str] (default: ())

Sites to return; by default only sample sites not present in posterior_samples are returned.

parallel : bool (default: False)

Predict in parallel by wrapping the existing model in an outermost plate messenger. Note that this requires that the model has all batch dims correctly annotated via plate.

Return type:

Predictive
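The interaction between num_samples and posterior_samples described above can be summarised in a few lines. The helper below is a hypothetical illustration of that rule only; it is not part of scvi-tools or Pyro, and it uses nested lists in place of tensors so that len() plays the role of the leading dimension size.

```python
def effective_num_samples(posterior_samples, num_samples):
    """Return the number of samples that would actually be used.

    Hypothetical helper: mirrors the documented rule that num_samples
    is ignored whenever posterior_samples is non-empty.
    """
    if posterior_samples:
        # Use the leading dimension of the first site's samples.
        first_site = next(iter(posterior_samples.values()))
        return len(first_site)
    return num_samples


# Three posterior draws of a 2-dimensional site named "z" (made-up values).
samples = {"z": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]}

with_samples = effective_num_samples(samples, num_samples=100)
without_samples = effective_num_samples(None, num_samples=100)
```

In this sketch, with_samples follows the three supplied draws rather than the requested 100, while without_samples falls back to num_samples.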

forward#

PyroBaseModuleClass.forward(*args, **kwargs)[source]#

Passthrough to Pyro model.

on_load#

PyroBaseModuleClass.on_load(model)[source]#

Callback function run in scvi.model.base.BaseModelClass.load() prior to loading module state dict.

For some Pyro modules with AutoGuides, run one training step prior to loading state dict.