scvi.model.base.PyroSampleMixin


class scvi.model.base.PyroSampleMixin

Mixin class for generating samples from the posterior distribution.

Works both when the data are loaded in minibatches and when the full dataset is used at once.
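When the data are processed in minibatches, samples of variables that live on the observation plate arrive one minibatch at a time and have to be stitched back together along the observation axis, while global variables do not depend on which observations were in the minibatch. The sketch below illustrates that idea with plain NumPy. The helper `merge_minibatch_samples`, the `obs_sites` set, the assumption that the observation axis is axis 1 (right after the sample axis) and the choice to take global variables from the first minibatch are all hypothetical; this is not the scvi-tools implementation.

```python
import numpy as np


def merge_minibatch_samples(minibatch_samples, obs_sites, obs_axis=1):
    """Merge per-minibatch posterior samples into full-data arrays (hypothetical helper).

    minibatch_samples: list of dicts mapping site name -> array of shape (n_samples, ...).
    obs_sites: names of sites assumed to live on the observation plate along obs_axis.
    """
    first = minibatch_samples[0]
    merged = {}
    for site, value in first.items():
        if site in obs_sites:
            # Concatenate the observation dimension across minibatches, in order.
            merged[site] = np.concatenate(
                [batch[site] for batch in minibatch_samples], axis=obs_axis
            )
        else:
            # Assumption for this sketch: global sites are taken from the first minibatch.
            merged[site] = value
    return merged


rng = np.random.default_rng(0)
n_samples = 4
# Two minibatches of 3 and 2 observations; "z" is per-observation, "theta" is global.
batches = [
    {"z": rng.normal(size=(n_samples, 3, 5)), "theta": rng.normal(size=(n_samples, 10))},
    {"z": rng.normal(size=(n_samples, 2, 5)), "theta": rng.normal(size=(n_samples, 10))},
]
merged = merge_minibatch_samples(batches, obs_sites={"z"})
print(merged["z"].shape, merged["theta"].shape)  # (4, 5, 5) (4, 10)
```

Keeping the sample axis first means every merged array still has shape (n_samples, ...), which is the layout the posterior summaries are computed over.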

Methods table

sample_posterior([num_samples, ...])

Summarise the posterior distribution.

Methods

PyroSampleMixin.sample_posterior(num_samples=1000, return_sites=None, accelerator='auto', device='auto', batch_size=None, return_observed=False, return_samples=False, summary_fun=None)

Summarise the posterior distribution.

Generate samples from the posterior distribution for each parameter and compute the mean, the 5th and 95th quantiles, and the standard deviation.
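Each summary is computed over the sample axis (axis 0) of an array of shape (num_samples, ...), so the result has the shape of a single sample. A minimal NumPy sketch for one variable, assuming the 5th and 95th quantiles are the 0.05 and 0.95 quantiles with NumPy's default linear interpolation; the helper name `summarise_samples` and the summary names used as keys are assumptions for illustration.

```python
import numpy as np


def summarise_samples(samples):
    """Summarise posterior samples of one variable along the sample axis (axis 0)."""
    return {
        "means": np.mean(samples, axis=0),
        "q05": np.quantile(samples, 0.05, axis=0),
        "q95": np.quantile(samples, 0.95, axis=0),
        "stds": np.std(samples, axis=0),
    }


# 1000 posterior samples of a variable with 3 entries.
rng = np.random.default_rng(42)
samples = rng.normal(loc=[0.0, 1.0, 5.0], scale=[1.0, 0.5, 2.0], size=(1000, 3))
summary = summarise_samples(samples)
for name, value in summary.items():
    print(name, value.shape)  # each summary has shape (3,)
```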

Parameters:
  • num_samples (int (default: 1000)) – Number of posterior samples to generate.

  • return_sites (list | None (default: None)) – List of variables for which to generate posterior samples. Defaults to all variables.

  • accelerator (str (default: 'auto')) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • device (int | str (default: 'auto')) – The device to use. Can be set to a non-negative index (int or str) or “auto” for automatic selection based on the chosen accelerator. If set to “auto” and the accelerator is not determined to be “cpu”, the first available device is used.

  • batch_size (int | None (default: None)) – Minibatch size used when loading data into the model. Defaults to scvi.settings.batch_size.

  • return_observed (bool (default: False)) – Whether to return observed sites/variables. The observed count matrix can be very large, so it is not returned by default.

  • return_samples (bool (default: False)) – Whether to return all generated posterior samples in addition to the sample mean, 5th/95th quantiles and SD.

  • summary_fun (dict[str, Callable] | None (default: None)) – A dict of the form {“means”: np.mean, “std”: np.std} specifying which posterior distribution summaries to compute and the names under which to return them. See Returns below for the defaults.
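The sketch below shows how a summary_fun dictionary maps names to summaries that are applied to each variable over the sample axis, and how return_samples adds the raw samples. It rests on two assumptions not stated on this page: each function is called as `fun(samples, axis=0)`, so a custom function must accept an `axis` keyword the way np.mean and np.std do; and each summary is stored under the key `post_sample_<name>`. The helper `apply_summary_fun` is hypothetical and only mimics the documented parameters.

```python
import numpy as np


def apply_summary_fun(samples, summary_fun, return_samples=False):
    """Hypothetical mimic of the summarisation step of sample_posterior.

    samples: dict mapping variable name -> array of shape (n_samples, ...).
    summary_fun: dict mapping summary name -> callable accepting an ``axis`` keyword.
    """
    results = {}
    for name, fun in summary_fun.items():
        # One dict per summary, keyed by variable name (assumed naming scheme).
        results[f"post_sample_{name}"] = {
            var: fun(values, axis=0) for var, values in samples.items()
        }
    if return_samples:
        results["posterior_samples"] = samples
    return results


# Custom summaries: mean, median, and a 90% interval width written as a lambda with an axis argument.
summary_fun = {
    "means": np.mean,
    "median": np.median,
    "iw90": lambda x, axis: np.quantile(x, 0.95, axis=axis) - np.quantile(x, 0.05, axis=axis),
}
rng = np.random.default_rng(1)
samples = {"w": rng.gamma(2.0, 1.0, size=(500, 4)), "s": rng.normal(size=(500,))}
results = apply_summary_fun(samples, summary_fun, return_samples=True)
print(sorted(results))
```

Requiring an `axis` argument keeps NumPy reductions usable directly as summary functions; anything else (such as the interval width above) can be wrapped in a small lambda.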

Returns:

By default, a dictionary with the following entries:

post_sample_means: Dict[str, np.ndarray]

Mean of the posterior distribution for each variable, as a dictionary mapping each variable name to a numpy array;

post_sample_q05: Dict[str, np.ndarray]

5th quantile of the posterior distribution for each variable;

post_sample_q95: Dict[str, np.ndarray]

95th quantile of the posterior distribution for each variable;

post_sample_stds: Dict[str, np.ndarray]

Standard deviation of the posterior distribution for each variable;

posterior_samples: Optional[Dict[str, np.ndarray]]

Posterior distribution samples for each variable as numpy arrays of shape (n_samples, …); only returned when return_samples=True.

Notes

Note for developers: this method requires the list_obs_plate_vars property to be overridden; it lists the observation/minibatch plate name and the variables on that plate. See list_obs_plate_vars for details of the variables it should contain. To keep all model-specific variables in one place, this dictionary can be returned by the model class property self.module.model.list_obs_plate_vars.
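For developers, a sketch of what an overridden list_obs_plate_vars property might look like on a model class. The key layout ("name" for the plate name, "in" for model-argument indices, "sites" mapping each observation-plate variable to its non-plate dimension) is an assumption based on our reading of the scvi-tools default; check list_obs_plate_vars for the authoritative description. The class `MyModel` and the site names are hypothetical, and the sketch needs no Pyro installation.

```python
class MyModel:
    """Hypothetical stand-in for a Pyro model class."""

    @property
    def list_obs_plate_vars(self):
        # Assumed layout; verify against the list_obs_plate_vars documentation.
        return {
            "name": "obs_plate",  # name of the observation/minibatch plate
            "in": [0],  # assumed: indices of model args passed to an encoder (amortised inference)
            "sites": {"z_cell": 10, "library_size": 1},  # per-observation site -> non-plate dimension
        }


plate_vars = MyModel().list_obs_plate_vars
obs_sites = set(plate_vars["sites"])
print(plate_vars["name"], sorted(obs_sites))
```

Returning this dictionary from the model class keeps the plate name and the list of per-observation variables next to the model definition that creates them.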