scvi.external.mrvi.MRVAE
- class scvi.external.mrvi.MRVAE(n_input, n_sample, n_batch, n_labels, n_latent=30, n_latent_u=10, encoder_n_hidden=128, encoder_n_layers=2, z_u_prior=True, z_u_prior_scale=0.0, u_prior_scale=0.0, u_prior_mixture=True, u_prior_mixture_k=20, learn_z_u_prior_scale=False, scale_observations=False, px_kwargs=None, qz_kwargs=None, qu_kwargs=None, training=True, n_obs_per_sample=None, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: JaxBaseModuleClass
Multi-resolution Variational Inference (MrVI) module.
- Parameters:
  - n_input (int) – Number of input features.
  - n_sample (int) – Number of samples.
  - n_batch (int) – Number of batches.
  - n_labels (int) – Number of labels.
  - n_latent (int (default: 30)) – Number of latent dimensions for z.
  - n_latent_u (int (default: 10)) – Number of latent dimensions for u.
  - encoder_n_hidden (int (default: 128)) – Number of hidden units in the encoder.
  - encoder_n_layers (int (default: 2)) – Number of layers in the encoder.
  - z_u_prior (bool (default: True)) – Whether to place a Gaussian prior on z given u.
  - z_u_prior_scale (float (default: 0.0)) – Natural log of the scale parameter of the Gaussian prior placed on z given u. Only applies if learn_z_u_prior_scale is False.
  - u_prior_scale (float (default: 0.0)) – Natural log of the scale parameter of the Gaussian prior placed on u. If u_prior_mixture is True, this scale applies to each mixture component distribution.
  - u_prior_mixture (bool (default: True)) – Whether to use a mixture-of-Gaussians prior for u.
  - u_prior_mixture_k (int (default: 20)) – Number of mixture components in the mixture-of-Gaussians prior on u.
  - learn_z_u_prior_scale (bool (default: False)) – Whether to learn the scale parameter of the prior distribution of z given u.
  - scale_observations (bool (default: False)) – Whether to scale the loss associated with each observation by the total number of observations linked to its sample.
  - px_kwargs (dict | None (default: None)) – Keyword arguments for the generative model.
  - qz_kwargs (dict | None (default: None)) – Keyword arguments for the inference model from u to z.
  - qu_kwargs (dict | None (default: None)) – Keyword arguments for the inference model from x to u.
  - training (bool (default: True)) – Whether the model is in training mode.
  - n_obs_per_sample (Union[Array, ndarray, bool_, number, bool, int, float, complex, None] (default: None)) – Number of observations per sample.
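Both z_u_prior_scale and u_prior_scale are given on the natural-log scale, so the default of 0.0 corresponds to a standard deviation of exp(0.0) = 1.0. The plain-Python sketch below shows that conversion. It also evaluates the log-density of a mixture-of-Gaussians prior on u whose components all share that scale, as u_prior_mixture describes. This is an illustration, not MRVAE's implementation. The equal mixture weights, the component means, and the helper names (`prior_scale`, `mixture_log_prob`) are hypothetical choices made for the example. The module may learn its weights and means.

```python
import math


def prior_scale(log_scale: float) -> float:
    """Convert a natural-log scale (as in z_u_prior_scale / u_prior_scale) to a standard deviation."""
    return math.exp(log_scale)


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log-density of a univariate Gaussian N(loc, scale**2) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(values: list[float]) -> float:
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def mixture_log_prob(u: list[float], means: list[list[float]], log_scale: float) -> float:
    """Log-density of u under an equally weighted mixture of diagonal Gaussians.

    Every component uses the same scale exp(log_scale), mirroring how
    u_prior_scale applies to each component. The uniform weights 1/k are an
    assumption of this sketch.
    """
    scale = prior_scale(log_scale)
    log_weight = -math.log(len(means))  # hypothetical: uniform mixture weights
    component_terms = [
        log_weight + sum(normal_log_prob(ud, md, scale) for ud, md in zip(u, mean))
        for mean in means
    ]
    return logsumexp(component_terms)


# Example: 2-dimensional u, k = 3 hypothetical component means, default log-scale 0.0.
means = [[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]]
print(prior_scale(0.0))  # 1.0
print(mixture_log_prob([0.0, 0.0], means, 0.0))
```

Storing the scale as a log keeps any real-valued setting valid, because exp always returns a positive standard deviation.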
Attributes table
| | Returns a tuple of rng sequence names required for this Flax module. |
Methods table
| compute_h_from_x_eps | Compute normalized gene expression from observations using predefined eps. |
| | Generative model. |
| inference | Latent variable inference. |
| loss | Compute the loss function value. |
| setup | Flax setup method. |
Attributes
- MRVAE.encoder_n_layers: int = 2
- MRVAE.learn_z_u_prior_scale: bool = False
- MRVAE.n_latent: int = 30
- MRVAE.n_latent_u: int = 10
- MRVAE.n_obs_per_sample: jax.typing.ArrayLike | None = None
- MRVAE.px_kwargs: dict | None = None
- MRVAE.qu_kwargs: dict | None = None
- MRVAE.qz_kwargs: dict | None = None
- MRVAE.scale_observations: bool = False
- MRVAE.scope: Scope | None = None
- MRVAE.training: bool = True
- MRVAE.u_prior_mixture: bool = True
- MRVAE.u_prior_mixture_k: int = 20
- MRVAE.u_prior_scale: float = 0.0
- MRVAE.z_u_prior: bool = True
- MRVAE.z_u_prior_scale: float = 0.0
- MRVAE.n_input: int
- MRVAE.n_sample: int
- MRVAE.n_batch: int
- MRVAE.n_labels: int
Methods
- MRVAE.compute_h_from_x_eps(x, sample_index, batch_index, extra_eps, cf_sample=None, mc_samples=10)
Compute normalized gene expression from observations using predefined eps.
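Here the noise eps is predefined, so it is presumably supplied through the extra_eps argument rather than drawn internally. A general reason to fix the noise is the reparameterization trick with common random numbers. If the same eps is reused for two counterfactual inputs, the outputs differ only because of the inputs, not because of fresh sampling noise. The toy sketch below illustrates that idea in plain Python. The Gaussian location and scale, the per-gene `shift` standing in for a counterfactual sample, and the softmax "decoder" that yields proportions summing to 1 are all hypothetical stand-ins, not MRVAE's networks.

```python
import math
import random


def softmax(logits: list[float]) -> list[float]:
    """Turn logits into non-negative proportions that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_h(loc: list[float], scale: list[float], eps: list[float],
          shift: list[float]) -> list[float]:
    """Hypothetical stand-in for "normalized expression from predefined eps".

    Draws a latent vector with the reparameterization trick
    (loc + scale * eps) using caller-supplied noise, adds a per-gene
    `shift` that plays the role of a counterfactual condition, and
    normalizes with a softmax. None of this is MRVAE's actual network.
    """
    latent = [m + s * e + d for m, s, e, d in zip(loc, scale, eps, shift)]
    return softmax(latent)


rng = random.Random(0)
loc, scale = [0.5, -1.0, 2.0], [0.3, 0.3, 0.3]
eps = [rng.gauss(0.0, 1.0) for _ in loc]  # drawn once, then reused

h_a = toy_h(loc, scale, eps, shift=[0.0, 0.0, 0.0])
h_b = toy_h(loc, scale, eps, shift=[0.0, 1.0, 0.0])
# With eps fixed, h_a and h_b differ only because of `shift`.
```

Drawing new noise for each call would mix sampling variability into the comparison between h_a and h_b. Reusing eps removes that variability.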
- MRVAE.inference(x, sample_index, mc_samples=None, cf_sample=None, use_mean=False)
Latent variable inference.
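The inference method takes mc_samples and use_mean arguments. These names suggest a common VAE convention: return the posterior mean when use_mean is True, and otherwise draw mc_samples reparameterized samples loc + scale * eps. The plain-Python sketch below illustrates that convention for a diagonal Gaussian posterior. The function name `sample_latent` and its exact semantics, including the single draw when mc_samples is None, are assumptions and do not describe MRVAE's documented behavior.

```python
import random


def sample_latent(loc, scale, mc_samples=None, use_mean=False, rng=None):
    """Hypothetical sketch of posterior sampling for a diagonal Gaussian q(u | x).

    - use_mean=True: return the posterior mean (no noise).
    - mc_samples=None: return a single reparameterized draw.
    - mc_samples=n: return a list of n draws.
    """
    if use_mean:
        return list(loc)
    if rng is None:
        rng = random.Random()

    def draw():
        # Reparameterization trick: u = loc + scale * eps with eps ~ N(0, 1).
        return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]

    if mc_samples is None:
        return draw()
    return [draw() for _ in range(mc_samples)]


draws = sample_latent([1.0, 2.0], [0.5, 0.5], mc_samples=10, rng=random.Random(0))
```

Writing the draw as loc + scale * eps keeps the sample differentiable with respect to loc and scale. This is what allows a VAE to be trained by gradient descent through the sampling step.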
- MRVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)
Compute the loss function value.
- Return type:
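The loss method accepts a kl_weight argument, and the scale_observations and n_obs_per_sample options control how individual observations are weighted. The sketch below shows one plausible way to combine these pieces. Each observation's loss is its reconstruction term plus kl_weight times its KL term. When scale_observations is on, that loss is divided by the number of observations in the observation's sample, so every sample contributes comparably. The per-observation losses are then averaged. Several parts of this are assumptions: that "scale by" means division, the split into reconstruction and KL terms, and the final averaging. The function `weighted_loss` and its inputs are hypothetical.

```python
def weighted_loss(reconstruction, kl, sample_index, n_obs_per_sample=None,
                  kl_weight=1.0, scale_observations=False):
    """Hypothetical per-observation loss aggregation.

    reconstruction, kl: per-observation loss terms (lists of floats).
    sample_index: sample id of each observation.
    n_obs_per_sample: number of observations belonging to each sample id.
    """
    per_obs = [r + kl_weight * k for r, k in zip(reconstruction, kl)]
    if scale_observations:
        if n_obs_per_sample is None:
            raise ValueError("n_obs_per_sample is required when scale_observations=True")
        # Assumption: "scale by the number of observations" means divide by it.
        per_obs = [loss / n_obs_per_sample[s] for loss, s in zip(per_obs, sample_index)]
    return sum(per_obs) / len(per_obs)


sample_index = [0, 0, 0, 1]  # sample 0 has three observations, sample 1 has one
n_obs_per_sample = [3, 1]
recon, kl = [2.0, 2.0, 2.0, 2.0], [1.0, 1.0, 1.0, 1.0]
print(weighted_loss(recon, kl, sample_index))  # 3.0
print(weighted_loss(recon, kl, sample_index, n_obs_per_sample, scale_observations=True))  # 1.5
```

Without scaling, sample 0 contributes three times as much as sample 1 because it has three observations. With scaling, each sample's observations sum to the same per-sample total.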
- MRVAE.setup()
Flax setup method.
In scvi-tools, we prefer the setup parameterization of flax.linen Modules (rather than nn.compact) because it makes the interface more similar to PyTorch's. More about this can be found here:
https://flax.readthedocs.io/en/latest/design_notes/setup_or_nncompact.html