scvi.module.base.PyroBaseModuleClass.create_predictive

PyroBaseModuleClass.create_predictive(model=None, posterior_samples=None, guide=None, num_samples=None, return_sites=(), parallel=False)

Creates a Pyro Predictive object (pyro.infer.Predictive) for sampling from the predictive distribution.

Parameters
model : Optional[Callable] (default: None)

Python callable containing Pyro primitives. Defaults to self.model.

posterior_samples : Optional[dict] (default: None)

Dictionary of samples from the posterior, keyed by site name.

guide : Optional[Callable] (default: None)

Optional guide used to draw posterior samples of sites not present in posterior_samples. Defaults to self.guide.

num_samples : Optional[int] (default: None)

Number of samples to draw from the predictive distribution. This argument has no effect if posterior_samples is non-empty; in that case, the size of the leading dimension of the samples in posterior_samples is used instead.
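To make the sample-count rule concrete, here is a minimal pure-Python sketch. The helper resolve_num_samples is hypothetical (it is not part of scvi-tools or Pyro), plain lists stand in for tensors so that len(value) plays the role of the leading dimension, and raising an error when neither argument is usable is an assumption of the sketch.

```python
from typing import Dict, List, Optional


def resolve_num_samples(
    posterior_samples: Optional[Dict[str, List[float]]],
    num_samples: Optional[int],
) -> int:
    """Hypothetical helper illustrating the documented num_samples rule.

    Lists stand in for tensors: len(value) is the leading dimension.
    """
    if posterior_samples:
        # Non-empty posterior_samples: the leading dimension wins and
        # num_samples is ignored. Assumes all entries share that dimension.
        first = next(iter(posterior_samples.values()))
        return len(first)
    if num_samples is None:
        # Assumption: with no posterior samples, a count must be supplied.
        raise ValueError("num_samples is required when posterior_samples is empty")
    return num_samples


# num_samples=10 is ignored because posterior_samples is non-empty.
print(resolve_num_samples({"z": [0.1, 0.2, 0.3]}, num_samples=10))  # 3
print(resolve_num_samples(None, num_samples=10))  # 10
```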

return_sites : Tuple[str] (default: ())

Sites to return; by default only sample sites not present in posterior_samples are returned.
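The default site selection can be sketched the same way. select_return_sites is a hypothetical helper, model_sample_sites stands in for the names of the sample sites the model would record when traced, and the site names "z", "library" and "x" are made up for illustration. The sketch assumes that an empty return_sites means "use the default".

```python
from typing import Dict, Iterable, List, Sequence


def select_return_sites(
    model_sample_sites: Iterable[str],
    posterior_samples: Dict[str, Sequence[float]],
    return_sites: Sequence[str] = (),
) -> List[str]:
    """Hypothetical helper illustrating the documented return_sites default."""
    if return_sites:
        # Explicitly requested sites are returned as given.
        return list(return_sites)
    # Default: only sample sites that were not supplied in posterior_samples.
    return [site for site in model_sample_sites if site not in posterior_samples]


sites = ["z", "library", "x"]  # hypothetical site names
print(select_return_sites(sites, {"z": [0.5, 0.7]}))  # ['library', 'x']
print(select_return_sites(sites, {"z": [0.5]}, return_sites=("z",)))  # ['z']
```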

parallel : bool (default: False)

Predict in parallel by wrapping the existing model in an outermost plate messenger. This requires that all batch dimensions in the model are correctly annotated via plate. Default is False.

Return type

Predictive