scvi.model.base.PyroSampleMixin
- class scvi.model.base.PyroSampleMixin
Mixin class for generating samples from the posterior distribution.
Works with both minibatches and the full data.
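To illustrate what "works with minibatches" involves, here is a minimal sketch of combining per-minibatch samples into full-data samples. It assumes that variables inside the observation plate have the observation axis at a known position (`obs_dim`) and that global (non-plate) variables can simply be taken from the first minibatch. The helper `merge_minibatch_samples`, the `obs_dim` convention and the site names are hypothetical and are not the scvi-tools implementation.

```python
import numpy as np


def merge_minibatch_samples(batches, obs_sites, obs_dim=1):
    """Combine per-minibatch posterior samples (hypothetical helper).

    `batches` is a list of dicts mapping site name -> array of shape
    (n_samples, ...). Sites in `obs_sites` belong to the observation plate
    and are concatenated along `obs_dim`; every other (global) site is
    taken from the first minibatch.
    """
    merged = {}
    for site, value in batches[0].items():
        if site in obs_sites:
            merged[site] = np.concatenate([b[site] for b in batches], axis=obs_dim)
        else:
            merged[site] = value
    return merged


# Two minibatches of 3 and 2 cells, 10 posterior samples each.
rng = np.random.default_rng(0)
batches = [
    {"w": rng.normal(size=(10, 3, 4)), "alpha": rng.normal(size=(10, 1, 4))},
    {"w": rng.normal(size=(10, 2, 4)), "alpha": rng.normal(size=(10, 1, 4))},
]
full = merge_minibatch_samples(batches, obs_sites={"w"})
print(full["w"].shape)      # (10, 5, 4)
print(full["alpha"].shape)  # (10, 1, 4)
```

Knowing which sites are local to the observation plate is what makes this merge possible, which is why the developer note at the end of this page asks models to declare them.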
Methods table
| sample_posterior | Summarise posterior distribution. |
Methods
sample_posterior
- PyroSampleMixin.sample_posterior(num_samples=1000, return_sites=None, use_gpu=None, batch_size=None, return_observed=False, return_samples=False, summary_fun=None)
Summarise posterior distribution.
Generate samples from the posterior distribution for each parameter and compute the mean, the 5% and 95% quantiles, and the standard deviation.
- Parameters:
  - num_samples (int, default: 1000) – Number of posterior samples to generate.
  - return_sites (Optional[list], default: None) – List of variables for which to generate posterior samples; defaults to all variables.
  - use_gpu (Optional[bool], default: None) – Load the model on the default GPU if available (if None or True), on the GPU with the given index (if int), on the named GPU (if str), or on the CPU (if False).
  - batch_size (Optional[int], default: None) – Minibatch size for loading data into the model. Defaults to scvi.settings.batch_size.
  - return_observed (bool, default: False) – Whether to return observed sites/variables. The observed count matrix can be very large, so it is not returned by default.
  - return_samples (bool, default: False) – Whether to return all generated posterior samples in addition to the sample mean, 5%/95% quantiles and SD.
  - summary_fun (Optional[Dict[str, Callable]], default: None) – A dict of the form {"means": np.mean, "std": np.std} that specifies which posterior distribution summaries to compute and which names to use for them. See below for the default returns.
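The use_gpu description above accepts several value types. As an illustration of those documented semantics only, here is a hypothetical helper that maps a use_gpu-style value to a PyTorch-style device string; the name `resolve_device`, the `gpu_available` flag and the returned strings are assumptions, not part of the scvi-tools API.

```python
from typing import Optional, Union


def resolve_device(use_gpu: Optional[Union[bool, int, str]], gpu_available: bool) -> str:
    """Map a use_gpu-style value to a device string (hypothetical helper).

    None or True -> default GPU if available, else CPU;
    int -> GPU with that index; str -> named device; False -> CPU.
    Explicit indices and names are not checked for availability here.
    """
    # bool is a subclass of int, so booleans are handled before integers.
    if use_gpu is False:
        return "cpu"
    if use_gpu is None or use_gpu is True:
        return "cuda" if gpu_available else "cpu"
    if isinstance(use_gpu, int):
        return f"cuda:{use_gpu}"
    if isinstance(use_gpu, str):
        return use_gpu
    raise TypeError(f"unsupported use_gpu value: {use_gpu!r}")


print(resolve_device(None, gpu_available=False))  # cpu
print(resolve_device(1, gpu_available=True))      # cuda:1
```

The explicit `is False` / `is True` checks matter: without them, `isinstance(True, int)` would send booleans down the integer branch.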
- Returns:
  - post_sample_means: Dict[str, np.ndarray] – Mean of the posterior distribution, as a dictionary mapping each variable name to a numpy array;
  - post_sample_q05: Dict[str, np.ndarray] – 5% quantile of the posterior distribution for each variable;
  - post_sample_q95: Dict[str, np.ndarray] – 95% quantile of the posterior distribution for each variable;
  - post_sample_stds: Dict[str, np.ndarray] – Standard deviation of the posterior distribution for each variable;
  - posterior_samples: Optional[Dict[str, np.ndarray]] – Posterior samples for each variable as numpy arrays of shape (n_samples, …); returned only when return_samples=True.
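The summaries above amount to applying one reduction per summary name over the sample axis of each variable. The following is a minimal sketch of that idea with NumPy, assuming samples are stored as a dict of arrays with the sample dimension first and that each result is keyed as "post_sample_" plus the summary name; `summarise_samples` is a hypothetical helper, not the method's actual code.

```python
import numpy as np


def summarise_samples(samples, summary_fun=None):
    """Summarise posterior samples per variable (illustrative sketch).

    `samples` maps variable names to arrays of shape (n_samples, ...);
    every summary function is applied over the sample axis (axis 0).
    """
    if summary_fun is None:
        # Assumed defaults mirroring the documented returns.
        summary_fun = {
            "means": np.mean,
            "stds": np.std,
            "q05": lambda x, axis: np.quantile(x, 0.05, axis=axis),
            "q95": lambda x, axis: np.quantile(x, 0.95, axis=axis),
        }
    return {
        f"post_sample_{name}": {site: fun(x, axis=0) for site, x in samples.items()}
        for name, fun in summary_fun.items()
    }


rng = np.random.default_rng(0)
samples = {"w": rng.normal(loc=2.0, size=(1000, 3))}  # 1000 samples of a 3-vector
res = summarise_samples(samples)
print(res["post_sample_means"]["w"].shape)  # (3,)

# A custom summary_fun, as in the parameter description, changes the keys.
custom = summarise_samples(samples, {"means": np.mean, "std": np.std})
print(sorted(custom))  # ['post_sample_means', 'post_sample_std']
```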
Notes
Note for developers: this mixin requires an overridden list_obs_plate_vars property, which lists the observation/minibatch plate name and variables. See list_obs_plate_vars for details of the variables it should contain. This dictionary can be returned by the model class property self.module.model.list_obs_plate_vars to keep all model-specific variables in one place.
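For orientation, here is a sketch of what such an override might look like. It assumes the dictionary layout of the default list_obs_plate_vars in scvi-tools' Pyro module base class ("name" for the plate name, "in" for indices of model arguments passed to an encoder, "sites" mapping observation-plate variables to their non-plate size); check that reference before relying on it. The class `MyPyroModel`, the site names and the helper `split_sites` are hypothetical.

```python
class MyPyroModel:
    """Hypothetical Pyro model class illustrating list_obs_plate_vars."""

    @property
    def list_obs_plate_vars(self):
        # Assumed key layout: plate name, indices of model args given to an
        # encoder ("in"), and obs-plate sites mapped to their non-plate size.
        return {
            "name": "obs_plate",
            "in": [0],
            "sites": {"cell_abundance": 50},
        }


def split_sites(site_names, plate_vars):
    """Separate observation-plate sites from global sites (hypothetical helper)."""
    local_sites = [s for s in site_names if s in plate_vars["sites"]]
    global_sites = [s for s in site_names if s not in plate_vars["sites"]]
    return local_sites, global_sites


obs, glob = split_sites(["cell_abundance", "gene_level"], MyPyroModel().list_obs_plate_vars)
print(obs)   # ['cell_abundance']
print(glob)  # ['gene_level']
```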