scvi.model.base.PyroSampleMixin
- class scvi.model.base.PyroSampleMixin
Mixin class for generating samples from the posterior distribution.
Works with both minibatches and the full dataset.
Methods table
sample_posterior([num_samples, return_sites, ...]) – Summarise posterior distribution.
Methods
sample_posterior
- PyroSampleMixin.sample_posterior(num_samples=1000, return_sites=None, use_gpu=None, batch_size=None, return_observed=False, return_samples=False, summary_fun=None)
Summarise posterior distribution.
Generate samples from the posterior distribution for each parameter and compute the mean, 5%/95% quantiles and standard deviation.
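The summary step can be pictured with a short NumPy sketch. It is an illustration under assumptions, not the scvi-tools code: posterior samples are assumed to be stored as a dictionary mapping each variable name to an array whose first axis indexes the samples, and the helper `summarise_samples`, the site names `w_sf` and `alpha`, and the default summary names (`means`, `stds`, `q05`, `q95`) are hypothetical choices for the example.

```python
import numpy as np


def summarise_samples(samples, summary_fun=None):
    """Summarise posterior samples stored as {name: array of shape (n_samples, ...)}.

    Hypothetical helper for illustration; every summary reduces over axis 0,
    i.e. across posterior samples.
    """
    if summary_fun is None:
        # Assumed default summaries: mean, standard deviation, 5% and 95% quantiles.
        summary_fun = {
            "means": np.mean,
            "stds": np.std,
            "q05": lambda x, axis: np.quantile(x, 0.05, axis=axis),
            "q95": lambda x, axis: np.quantile(x, 0.95, axis=axis),
        }
    results = {}
    for name, fun in summary_fun.items():
        # One dictionary per summary, keyed by variable name.
        results[f"post_sample_{name}"] = {
            var: fun(values, axis=0) for var, values in samples.items()
        }
    return results


rng = np.random.default_rng(0)
samples = {
    "w_sf": rng.gamma(2.0, 1.0, size=(1000, 50, 10)),  # hypothetical per-observation site
    "alpha": rng.normal(size=(1000, 1)),  # hypothetical global site
}
summary = summarise_samples(samples)
print(sorted(summary))
print(summary["post_sample_means"]["w_sf"].shape)  # (50, 10)
```

Because each summary reduces over the sample axis only, per-observation variables keep their observation and factor dimensions in the output.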
- Parameters:
num_samples (int) – Number of posterior samples to generate.
return_sites (Optional[list]) – List of variables for which to generate posterior samples; defaults to all variables.
use_gpu (Optional[bool]) – Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
batch_size (Optional[int]) – Minibatch size for loading data into the model. Defaults to scvi.settings.batch_size.
return_observed (bool) – Whether to return observed sites/variables. The observed count matrix can be very large, so it is not returned by default.
return_samples (bool) – Whether to return all generated posterior samples in addition to the sample mean, 5%/95% quantiles and SD.
summary_fun (Optional[Dict[str, Callable]]) – A dict in the form {"means": np.mean, "std": np.std} specifying which posterior distribution summaries to compute and the names to use for them. See below for the default returns.
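A custom `summary_fun` could look like the sketch below. It assumes that each callable is invoked with the array of samples and an `axis` keyword selecting the sample axis, so NumPy reductions such as `np.median` work directly and functions with extra arguments can be adapted with `functools.partial`. The `apply_summaries` helper and the `post_sample_<name>` output keys only mimic that assumed calling convention and are hypothetical.

```python
from functools import partial

import numpy as np

# Custom summaries: names are arbitrary; each callable must accept axis=...
custom_summary_fun = {
    "medians": np.median,
    "q025": partial(np.quantile, q=0.025),
    "q975": partial(np.quantile, q=0.975),
}


def apply_summaries(samples, summary_fun):
    """Hypothetical helper: apply each summary across the sample axis (axis 0)."""
    return {
        f"post_sample_{name}": {var: fun(x, axis=0) for var, x in samples.items()}
        for name, fun in summary_fun.items()
    }


rng = np.random.default_rng(1)
samples = {"mu": rng.normal(loc=3.0, size=(2000, 4))}
result = apply_summaries(samples, custom_summary_fun)
print(result["post_sample_medians"]["mu"])  # roughly 3.0 for each of the 4 entries
```

Using `partial` keeps the quantile level fixed while leaving `axis` free, which is what a reduction over the sample axis needs.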
- Returns:
- post_sample_means: Dict[str, np.ndarray] Mean of the posterior distribution for each variable, as a dictionary of numpy arrays keyed by variable name;
- post_sample_q05: Dict[str, np.ndarray] 5% quantile of the posterior distribution for each variable;
- post_sample_q95: Dict[str, np.ndarray] 95% quantile of the posterior distribution for each variable;
- post_sample_stds: Dict[str, np.ndarray] Standard deviation of the posterior distribution for each variable;
- posterior_samples: Optional[Dict[str, np.ndarray]] Posterior distribution samples for each variable as numpy arrays of shape (n_samples, ...) (optional; returned only when return_samples=True).
Notes
Note for developers: this method requires an overridden list_obs_plate_vars property, which lists the observation/minibatch plate name and variables. See list_obs_plate_vars for details of the variables it should contain. This dictionary can be returned by the model class property self.module.model.list_obs_plate_vars to keep all model-specific variables in one place.
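A developer-side sketch of such a property is shown below, together with one way the listed sites could be used to merge minibatch samples. It assumes a dictionary layout with a plate `name`, the model-argument indexes `in`, and a `sites` mapping from observation-plate variable names to their non-plate dimension; the class `MyPyroModel`, the site names, the sizes and the `merge_minibatch_samples` helper are all hypothetical. The merging rule (concatenate observation-plate sites along the observation axis, keep global sites from the first minibatch) is an illustrative assumption, not the library's code.

```python
import numpy as np


class MyPyroModel:
    """Hypothetical model illustrating the list_obs_plate_vars annotation."""

    def __init__(self, n_factors):
        self.n_factors = n_factors

    @property
    def list_obs_plate_vars(self):
        return {
            "name": "obs_plate",  # name of the observation/minibatch plate
            "in": [0],  # indexes of model args passed to an encoder (amortised inference)
            "sites": {"w_sf": self.n_factors},  # obs-plate site -> non-plate dimension
        }


def merge_minibatch_samples(batches, obs_sites, obs_axis=1):
    """Concatenate observation-plate sites across minibatches.

    Each element of ``batches`` maps site names to arrays of shape
    (n_samples, n_obs_in_batch, ...). Global sites are taken from the first batch.
    """
    merged = {}
    for site, first in batches[0].items():
        if site in obs_sites:
            merged[site] = np.concatenate([b[site] for b in batches], axis=obs_axis)
        else:
            merged[site] = first
    return merged


model = MyPyroModel(n_factors=5)
obs_sites = model.list_obs_plate_vars["sites"]
rng = np.random.default_rng(2)
batches = [
    {"w_sf": rng.random((100, n_obs, 5)), "alpha": rng.random((100, 1))}
    for n_obs in (64, 64, 22)
]
merged = merge_minibatch_samples(batches, obs_sites)
print(merged["w_sf"].shape)  # (100, 150, 5)
```

Keeping the site list in the model's property means the sampling code does not need to know which variables are per-observation; it only reads the annotation.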