scvi.external.mrvi.MRVAE

class scvi.external.mrvi.MRVAE(n_input, n_sample, n_batch, n_labels, n_latent=30, n_latent_u=10, encoder_n_hidden=128, encoder_n_layers=2, z_u_prior=True, z_u_prior_scale=0.0, u_prior_scale=0.0, u_prior_mixture=True, u_prior_mixture_k=20, learn_z_u_prior_scale=False, scale_observations=False, px_kwargs=None, qz_kwargs=None, qu_kwargs=None, training=True, n_obs_per_sample=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: JaxBaseModuleClass

Multi-resolution Variational Inference (MrVI) module.

Parameters:
  • n_input (int) – Number of input features.

  • n_sample (int) – Number of samples.

  • n_batch (int) – Number of batches.

  • n_labels (int) – Number of labels.

  • n_latent (int (default: 30)) – Number of latent variables for z.

  • n_latent_u (int (default: 10)) – Number of latent variables for u.

  • encoder_n_hidden (int (default: 128)) – Number of hidden units in the encoder.

  • encoder_n_layers (int (default: 2)) – Number of layers in the encoder.

  • z_u_prior (bool (default: True)) – Whether to place a Gaussian prior on z given u.

  • z_u_prior_scale (float (default: 0.0)) – Natural log of the scale parameter of the Gaussian prior placed on z given u. Only applies if learn_z_u_prior_scale is False.

  • u_prior_scale (float (default: 0.0)) – Natural log of the scale parameter of the Gaussian prior placed on u. If u_prior_mixture is True, this scale applies to each mixture component distribution.

  • u_prior_mixture (bool (default: True)) – Whether to use a mixture of Gaussians prior for u.

  • u_prior_mixture_k (int (default: 20)) – Number of mixture components to use for the mixture of Gaussians prior on u.

  • learn_z_u_prior_scale (bool (default: False)) – Whether to learn the scale parameter of the prior distribution of z given u.

  • scale_observations (bool (default: False)) – Whether to scale the loss associated with each observation by the total number of observations linked to the associated sample.

  • px_kwargs (dict | None (default: None)) – Keyword arguments for the generative model.

  • qz_kwargs (dict | None (default: None)) – Keyword arguments for the inference model from u to z.

  • qu_kwargs (dict | None (default: None)) – Keyword arguments for the inference model from x to u.

  • training (bool (default: True)) – Whether the model is in training mode.

  • n_obs_per_sample (jax.typing.ArrayLike | None (default: None)) – Number of observations per sample.
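Because z_u_prior_scale and u_prior_scale are given on the natural-log scale, the default 0.0 corresponds to a standard deviation of 1. The minimal JAX sketch below illustrates that convention and the kind of mixture-of-Gaussians log-density that u_prior_mixture=True describes. It is not MRVAE's implementation: the helper names, the component means, and the mixture logits are hypothetical placeholders (in the module such quantities would presumably be learned).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def gaussian_log_prob(x, mean, log_scale):
    """Log-density of an isotropic Gaussian whose scale is given as a natural log."""
    scale = jnp.exp(log_scale)  # log_scale = 0.0 -> scale = 1.0
    return norm.logpdf(x, loc=mean, scale=scale).sum(axis=-1)


def mixture_log_prob(u, means, logits, log_scale):
    """Log-density of a K-component Gaussian mixture sharing one log-scale.

    Hypothetical helper. u: (n, d), means: (K, d), logits: (K,).
    """
    # Log-density of every point under every component: shape (n, K).
    comp_log_prob = gaussian_log_prob(u[:, None, :], means[None, :, :], log_scale)
    log_weights = jax.nn.log_softmax(logits)  # normalized mixture weights (log space)
    return logsumexp(comp_log_prob + log_weights, axis=-1)


# Hypothetical sizes matching the defaults n_latent_u=10, u_prior_mixture_k=20.
key_u, key_means = jax.random.split(jax.random.PRNGKey(0))
u = jax.random.normal(key_u, (5, 10))
means = jax.random.normal(key_means, (20, 10))
logits = jnp.zeros(20)

log_p_single = gaussian_log_prob(u, 0.0, 0.0)        # analogue of u_prior_mixture=False
log_p_mix = mixture_log_prob(u, means, logits, 0.0)  # analogue of u_prior_mixture=True
```

With a single component the mixture reduces to the plain Gaussian, which is a convenient sanity check for any implementation of this kind of prior.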


Methods table

  • compute_h_from_x_eps(x, sample_index, ...[, ...]) – Compute normalized gene expression from observations using predefined eps.

  • generative(z, library, batch_index, label_index) – Generative model.

  • inference(x, sample_index[, mc_samples, ...]) – Latent variable inference.

  • loss(tensors, inference_outputs, ...[, ...]) – Compute the loss function value.

  • setup() – Flax setup method.

Attributes

MRVAE.encoder_n_hidden: int = 128
MRVAE.encoder_n_layers: int = 2
MRVAE.learn_z_u_prior_scale: bool = False
MRVAE.n_latent: int = 30
MRVAE.n_latent_u: int = 10
MRVAE.n_obs_per_sample: jax.typing.ArrayLike | None = None
MRVAE.name: Optional[str] = None
MRVAE.parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
MRVAE.px_kwargs: dict | None = None
MRVAE.qu_kwargs: dict | None = None
MRVAE.qz_kwargs: dict | None = None
MRVAE.required_rngs
MRVAE.scale_observations: bool = False
MRVAE.scope: Optional[Scope] = None
MRVAE.training: bool = True
MRVAE.u_prior_mixture: bool = True
MRVAE.u_prior_mixture_k: int = 20
MRVAE.u_prior_scale: float = 0.0
MRVAE.z_u_prior: bool = True
MRVAE.z_u_prior_scale: float = 0.0
MRVAE.n_input: int
MRVAE.n_sample: int
MRVAE.n_batch: int
MRVAE.n_labels: int

Methods

MRVAE.compute_h_from_x_eps(x, sample_index, batch_index, extra_eps, cf_sample=None, mc_samples=10)

Compute normalized gene expression from observations using predefined eps.

MRVAE.generative(z, library, batch_index, label_index)

Generative model.

Return type:

dict[str, Array | Distribution]

MRVAE.inference(x, sample_index, mc_samples=None, cf_sample=None, use_mean=False)

Latent variable inference.

Return type:

dict[str, Array | Distribution]
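The mc_samples and use_mean arguments suggest the usual reparameterized Monte Carlo sampling from an approximate posterior. The sketch below assumes a diagonal Gaussian posterior given by a mean and a log-scale, and assumes that use_mean=True returns the posterior mean instead of sampling. The helper sample_latent and the shapes are hypothetical and not taken from MRVAE's source.

```python
import jax
import jax.numpy as jnp


def sample_latent(key, mean, log_scale, mc_samples=None, use_mean=False):
    """Draw reparameterized samples from a diagonal Gaussian q(u | x).

    Hypothetical helper. mean, log_scale: (n_cells, n_latent).
    Returns (n_cells, n_latent) if mc_samples is None,
    otherwise (mc_samples, n_cells, n_latent).
    """
    if use_mean:
        # Deterministic embedding: return the mean, broadcast over MC samples if requested.
        if mc_samples is None:
            return mean
        return jnp.broadcast_to(mean, (mc_samples, *mean.shape))
    shape = mean.shape if mc_samples is None else (mc_samples, *mean.shape)
    eps = jax.random.normal(key, shape)
    # Reparameterization trick: u = mu + sigma * eps keeps gradients w.r.t. mu and sigma.
    return mean + jnp.exp(log_scale) * eps


key = jax.random.PRNGKey(0)
mean = jnp.zeros((4, 10))            # 4 cells, n_latent_u = 10
log_scale = jnp.full((4, 10), -1.0)  # sigma = exp(-1)

u_single = sample_latent(key, mean, log_scale)
u_mc = sample_latent(key, mean, log_scale, mc_samples=3)
u_det = sample_latent(key, mean, log_scale, mc_samples=3, use_mean=True)
```

Putting the Monte Carlo axis first lets downstream code average over samples with a single reduction along axis 0.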

MRVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)

Compute the loss function value.

Return type:

LossOutput
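As a rough illustration of how kl_weight and scale_observations might enter a loss of this kind, the sketch below combines per-cell reconstruction and KL terms and, optionally, divides each cell's loss by the number of cells in its sample. Dividing is one reading of the scale_observations description and is an assumption here, as is the function name weighted_elbo_loss; the real method returns a LossOutput, not a scalar.

```python
import jax.numpy as jnp


def weighted_elbo_loss(reconstruction_loss, kl, sample_index, kl_weight=1.0,
                       n_obs_per_sample=None):
    """Negative ELBO with KL weighting and optional per-sample rescaling.

    Hypothetical helper.
    reconstruction_loss, kl: per-cell terms, shape (n_cells,).
    sample_index: integer sample id of each cell, shape (n_cells,).
    n_obs_per_sample: cell count for every sample id, shape (n_sample,), or None.
    """
    per_cell = reconstruction_loss + kl_weight * kl
    if n_obs_per_sample is not None:
        # Assumed reading of scale_observations: cells from large samples are
        # down-weighted so that each sample contributes comparably.
        per_cell = per_cell / n_obs_per_sample[sample_index]
    return jnp.mean(per_cell)


sample_index = jnp.array([0, 0, 0, 1])
n_obs = jnp.bincount(sample_index, length=2)  # cells per sample: [3, 1]
rec = jnp.array([1.0, 1.0, 1.0, 1.0])
kl = jnp.array([2.0, 2.0, 2.0, 2.0])

plain = weighted_elbo_loss(rec, kl, sample_index, kl_weight=0.5)   # 2.0
scaled = weighted_elbo_loss(rec, kl, sample_index, kl_weight=0.5,
                            n_obs_per_sample=n_obs)                # 1.0
```

In the scaled case the three cells of sample 0 together contribute as much as the single cell of sample 1.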

MRVAE.setup()

Flax setup method.

With scvi-tools we prefer to use the setup parameterization of flax.linen Modules, which makes the interface more like PyTorch. More about this can be found here:

https://flax.readthedocs.io/en/latest/design_notes/setup_or_nncompact.html