scvi.model.base.PyroSampleMixin

class scvi.model.base.PyroSampleMixin

Mixin class for generating samples from the posterior distribution.

Works with both minibatches and the full dataset.
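When sampling is done in minibatches, posterior samples for observation-level variables arrive in pieces that have to be stitched back together. The conceptual sketch below illustrates that idea and is not the scvi-tools implementation. It assumes that observation-plate variables have their observation axis at position 1 (directly after the sample axis), and that for global variables a single copy from the first minibatch is kept. The helper `merge_minibatch_samples` and the variable names `w_obs` and `global_scale` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_factors = 100, 5
batch_sizes = [64, 64, 22]  # three minibatches covering 150 observations

# Hypothetical per-minibatch posterior samples:
# "w_obs" is an observation-plate variable, shape (n_samples, n_obs_in_batch, n_factors);
# "global_scale" is a global variable, shape (n_samples, 1).
minibatch_samples = [
    {
        "w_obs": rng.gamma(2.0, 1.0, size=(n_samples, b, n_factors)),
        "global_scale": rng.normal(size=(n_samples, 1)),
    }
    for b in batch_sizes
]

obs_plate_sites = {"w_obs"}  # hypothetical: names of observation-plate variables


def merge_minibatch_samples(batches, obs_sites, obs_axis=1):
    """Concatenate observation-plate sites along the (assumed) observation axis
    and keep global sites from the first minibatch only."""
    merged = {}
    for name in batches[0]:
        if name in obs_sites:
            merged[name] = np.concatenate([b[name] for b in batches], axis=obs_axis)
        else:
            merged[name] = batches[0][name]
    return merged


samples = merge_minibatch_samples(minibatch_samples, obs_plate_sites)
print(samples["w_obs"].shape)         # (100, 150, 5)
print(samples["global_scale"].shape)  # (100, 1)
```

Running on the full data corresponds to the special case of a single minibatch, where no concatenation is needed.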

Methods table

sample_posterior([num_samples, ...])

Summarise the posterior distribution.

Methods

sample_posterior

PyroSampleMixin.sample_posterior(num_samples=1000, return_sites=None, use_gpu=None, batch_size=None, return_observed=False, return_samples=False, summary_fun=None)

Summarise the posterior distribution.

Generate samples from the posterior distribution for each parameter and compute the mean, the 5% and 95% quantiles, and the standard deviation.
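The summary step can be sketched as below, assuming the samples are held in a dict that maps each variable name to a numpy array whose first axis indexes posterior samples, i.e. the (n_samples, …) layout. Each statistic is taken over that first axis, so every summary keeps the variable's own shape. This is an illustration, not the library's code: the helper `summarise_samples`, the variable name `w_sf` and the result key names are illustrative choices.

```python
import numpy as np


def summarise_samples(samples):
    """Compute mean, 5%/95% quantiles and SD over the sample axis (axis 0).

    `samples` maps variable names to arrays of shape (n_samples, ...).
    """
    return {
        "post_sample_means": {k: np.mean(v, axis=0) for k, v in samples.items()},
        "post_sample_q05": {k: np.quantile(v, 0.05, axis=0) for k, v in samples.items()},
        "post_sample_q95": {k: np.quantile(v, 0.95, axis=0) for k, v in samples.items()},
        "post_sample_stds": {k: np.std(v, axis=0) for k, v in samples.items()},
    }


rng = np.random.default_rng(0)
# Hypothetical posterior samples: 1000 draws of a (50 observations x 10 factors) variable.
samples = {"w_sf": rng.gamma(shape=2.0, scale=1.0, size=(1000, 50, 10))}
summary = summarise_samples(samples)
print(summary["post_sample_means"]["w_sf"].shape)  # (50, 10)
```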

Parameters:
  • num_samples (int (default: 1000)) – Number of posterior samples to generate.

  • return_sites (Optional[list] (default: None)) – List of variables for which to generate posterior samples. Defaults to all variables.

  • use_gpu (Optional[bool] (default: None)) – Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

  • batch_size (Optional[int] (default: None)) – Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.

  • return_observed (bool (default: False)) – Whether to return observed sites/variables. The observed count matrix can be very large, so it is not returned by default.

  • return_samples (bool (default: False)) – Whether to return all generated posterior samples in addition to the sample mean, 5%/95% quantiles and SD.

  • summary_fun (Optional[Dict[str, Callable]] (default: None)) – A dict in the form {"means": np.mean, "std": np.std} specifying which posterior distribution summaries to compute and which names to use for them. See Returns below for the default summaries.
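A custom summary_fun might look like the dict below. One assumption not stated in this reference: each callable is invoked on a variable's sample array together with an `axis=0` keyword argument. `np.mean` and `np.std` accept that directly, while quantile-style summaries need a small wrapper. The keys `median`, `q025` and `q975` are hypothetical names, and the commented-out call assumes an already-trained model object named `model`.

```python
import numpy as np

# Each value takes an array of posterior samples plus an `axis` keyword
# (assumed calling convention) and reduces over that axis.
summary_fun = {
    "means": np.mean,
    "stds": np.std,
    "median": lambda x, axis: np.median(x, axis=axis),
    "q025": lambda x, axis: np.quantile(x, 0.025, axis=axis),
    "q975": lambda x, axis: np.quantile(x, 0.975, axis=axis),
}

# Quick check on fake samples shaped (n_samples, n_obs).
fake = np.random.default_rng(1).normal(size=(500, 30))
results = {name: fun(fake, axis=0) for name, fun in summary_fun.items()}
print({name: r.shape for name, r in results.items()})  # every summary has shape (30,)

# With a trained model (hypothetical object `model`), the dict would be passed as:
# post = model.sample_posterior(num_samples=1000, summary_fun=summary_fun)
```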

Returns:

post_sample_means: Dict[str, np.ndarray]

Mean of the posterior distribution for each variable, as a dictionary mapping each variable name to a numpy array;

post_sample_q05: Dict[str, np.ndarray]

5% quantile of the posterior distribution for each variable;

post_sample_q95: Dict[str, np.ndarray]

95% quantile of the posterior distribution for each variable;

post_sample_stds: Dict[str, np.ndarray]

Standard deviation of the posterior distribution for each variable;

posterior_samples: Optional[Dict[str, np.ndarray]]

Posterior distribution samples for each variable as numpy arrays of shape (n_samples, …). Only returned if return_samples=True.

Notes

Note for developers: this method requires the list_obs_plate_vars property to be overridden. The property lists the observation/minibatch plate name and the variables in that plate; see list_obs_plate_vars for details of what it should contain. This dictionary can be returned by the model class property self.module.model.list_obs_plate_vars, which keeps all model-specific variables in one place.
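As a rough illustration for developers, a model class might expose the property as below. This is a hedged sketch. The class name `MyPyroModel`, the plate name `obs_plate` and the site names are hypothetical. The key names `name` and `sites`, and the convention that each site maps to its non-plate dimension, are assumptions. Only the two pieces of information the note mentions (the plate name and the observation-plate variables) are shown, so consult list_obs_plate_vars for the full set of keys the library expects.

```python
class MyPyroModel:
    """Hypothetical stand-in for a Pyro model class (e.g. self.module.model)."""

    def __init__(self, n_factors):
        self.n_factors = n_factors

    @property
    def list_obs_plate_vars(self):
        # Assumed layout: "name" is the observation/minibatch plate name and
        # "sites" maps each observation-plate variable to its non-plate dimension.
        return {
            "name": "obs_plate",
            "sites": {
                "w_sf": self.n_factors,
                "detection_y_s": 1,
            },
        }


model = MyPyroModel(n_factors=10)
plate_info = model.list_obs_plate_vars
print(plate_info["name"], sorted(plate_info["sites"]))
```

Keeping this dictionary on the model class means the sampling code can identify which variables to merge across minibatches without hard-coding model-specific names.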