scvi.model.base.PyroSampleMixin.sample_posterior

PyroSampleMixin.sample_posterior(num_samples=1000, return_sites=None, use_gpu=None, batch_size=None, return_observed=False, return_samples=False, summary_fun=None)

Summarise posterior distribution.

Generate samples from the posterior distribution for each parameter and compute the mean, 5%/95% quantiles and standard deviation.
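A minimal sketch of the summary step, assuming the posterior samples have already been collected into a dict that maps each site name to an array whose first axis indexes samples (shape (num_samples, …)). The helper `summarise_samples` and the site names are hypothetical, and the default summary set (mean, SD, 5%/95% quantiles over axis 0, returned under post_sample_<name> keys) is an assumption based on the description above, not the scvi-tools implementation.

```python
import numpy as np


def summarise_samples(samples):
    """Hypothetical helper: summarise a dict of posterior samples (axis 0 = sample index)."""
    # Assumed default summaries: mean, SD and 5%/95% quantiles over the sample axis.
    summary_fun = {
        "means": lambda x: np.mean(x, axis=0),
        "stds": lambda x: np.std(x, axis=0),
        "q05": lambda x: np.quantile(x, 0.05, axis=0),
        "q95": lambda x: np.quantile(x, 0.95, axis=0),
    }
    results = {}
    for name, fun in summary_fun.items():
        # One inner dict per summary, keyed by site name.
        results[f"post_sample_{name}"] = {
            site: fun(values) for site, values in samples.items()
        }
    return results


rng = np.random.default_rng(0)
# Hypothetical sites: a per-gene rate (20 genes) and a scalar dispersion.
samples = {
    "rate": rng.gamma(2.0, 1.0, size=(1000, 20)),
    "dispersion": rng.normal(5.0, 0.5, size=(1000,)),
}
summary = summarise_samples(samples)
print(summary["post_sample_means"]["rate"].shape)  # (20,)
```

Reducing over axis 0 removes the sample dimension, so each summary has the shape of a single draw of its site.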

Parameters
num_samples : int (default: 1000)

Number of posterior samples to generate.

return_sites : list | None (default: None)

List of variables (sample sites) for which to generate posterior samples. Defaults to all variables.

use_gpu : bool | None (default: None)

Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

batch_size : int | None (default: None)

Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.

return_observed : bool (default: False)

Whether to return observed sites/variables. The observed count matrix can be very large, so it is not returned by default.

return_samples : bool (default: False)

Whether to return the samples in addition to the sample mean, 5%/95% quantiles and SD.

summary_fun : dict[str, Callable] | None (default: None)

A dict of the form {"means": np.mean, "std": np.std} specifying which posterior distribution summaries to compute and the name under which each one is returned. If None, the default summaries listed under Returns are computed.
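A sketch of a custom summary_fun dict that adds a median and a 95% central interval. It assumes, as the np.mean/np.std example suggests, that each callable is applied to a site's samples array with an `axis` keyword selecting the sample axis; the exact calling convention is an assumption, as are the summary names and the stand-in samples below.

```python
import numpy as np

# Hypothetical custom summaries; each callable is assumed to receive the
# samples array plus an ``axis`` keyword, as np.mean and np.median accept.
summary_fun = {
    "means": np.mean,
    "medians": np.median,
    "q025": lambda x, axis: np.quantile(x, 0.025, axis=axis),
    "q975": lambda x, axis: np.quantile(x, 0.975, axis=axis),
}

# Apply the summaries to stand-in samples for one site (1000 samples x 3 values).
rng = np.random.default_rng(1)
site_samples = rng.normal(size=(1000, 3))
summaries = {name: fun(site_samples, axis=0) for name, fun in summary_fun.items()}
print(sorted(summaries))  # ['means', 'medians', 'q025', 'q975']
```

With a trained model (called `model` here for illustration), the dict would be passed as model.sample_posterior(summary_fun=summary_fun), and the results should then be labelled with the names chosen in the dict.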
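A back-of-the-envelope estimate of why return_observed defaults to False. It assumes that samples of the observed site would be held as a dense float32 array of shape (num_samples, n_obs, n_vars); both the layout and the dataset size are illustrative assumptions.

```python
def dense_sample_bytes(num_samples, n_obs, n_vars, bytes_per_value=4):
    """Size in bytes of a dense (num_samples, n_obs, n_vars) array (float32 by default)."""
    return num_samples * n_obs * n_vars * bytes_per_value


# Hypothetical dataset: 10,000 cells x 20,000 genes, with the default num_samples=1000.
size = dense_sample_bytes(1000, 10_000, 20_000)
print(f"{size / 1e9:.0f} GB")  # 800 GB
```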

Returns

post_sample_means: Dict[str, np.ndarray]

Mean of the posterior distribution for each variable, as a dictionary mapping variable names to numpy arrays;

post_sample_q05: Dict[str, np.ndarray]

5% quantile of the posterior distribution for each variable;

post_sample_q95: Dict[str, np.ndarray]

95% quantile of the posterior distribution for each variable;

post_sample_stds: Dict[str, np.ndarray]

Standard deviation of the posterior distribution for each variable;

posterior_samples: Optional[Dict[str, np.ndarray]]

Posterior distribution samples for each variable, as numpy arrays of shape (n_samples, …). Only returned when return_samples=True.
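The result can be read as a dict of dicts: first by summary name, then by variable name. Rather than calling the method on a trained model, the snippet below builds a small stand-in by hand, shaped as described above for return_samples=True; the site name "rate" and its values are made up for illustration.

```python
import numpy as np

# Hand-built stand-in for the method's output (hypothetical site "rate").
rng = np.random.default_rng(2)
draws = rng.gamma(2.0, 1.0, size=(1000, 5))  # (n_samples, n_genes)
results = {
    "post_sample_means": {"rate": draws.mean(axis=0)},
    "post_sample_q05": {"rate": np.quantile(draws, 0.05, axis=0)},
    "posterior_samples": {"rate": draws},
}

# Summaries are looked up by summary name, then by variable name.
rate_mean = results["post_sample_means"]["rate"]
# Raw samples keep the sample axis first.
rate_draws = results["posterior_samples"]["rate"]
print(rate_draws.shape, rate_mean.shape)  # (1000, 5) (5,)
```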

Notes

Note for developers: this method requires the model to override the list_obs_plate_vars property, which lists the name of the observation/minibatch plate and the variables that belong to it. See list_obs_plate_vars for details of what this dictionary should contain. It can be returned by the model class property self.module.model.list_obs_plate_vars, which keeps all model-specific variables in one place.
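A rough sketch of what such an override might look like on a model class. The class name, the plate name "obs_plate", the site "cell_scaling" and the dictionary keys ("name", "sites", with each site mapped to its number of non-plate dimensions) are all assumptions for illustration; the authoritative list of required keys is in the list_obs_plate_vars documentation.

```python
class MyPyroModel:
    """Hypothetical model class illustrating a list_obs_plate_vars override."""

    @property
    def list_obs_plate_vars(self):
        # Assumed layout: the observation/minibatch plate name plus the
        # sites inside that plate (site name -> number of non-plate dims).
        return {
            "name": "obs_plate",
            "sites": {"cell_scaling": 1},
        }


plate_info = MyPyroModel().list_obs_plate_vars
print(plate_info["name"])  # obs_plate
```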