scvi.external.mrvi_jax.JaxMRVAE
- class scvi.external.mrvi_jax.JaxMRVAE(n_input, n_sample, n_batch, n_labels, n_latent=30, n_latent_u=10, encoder_n_hidden=128, encoder_n_layers=2, z_u_prior=True, z_u_prior_scale=0.0, u_prior_scale=0.0, u_prior_mixture=True, u_prior_mixture_k=20, learn_z_u_prior_scale=False, scale_observations=False, px_kwargs=None, qz_kwargs=None, qu_kwargs=None, training=True, n_obs_per_sample=None, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: JaxBaseModuleClass

Multi-resolution Variational Inference (MrVI) module.
- Parameters:
  - n_input (int) – Number of input features.
  - n_sample (int) – Number of samples.
  - n_batch (int) – Number of batches.
  - n_labels (int) – Number of labels.
  - n_latent (int (default: 30)) – Dimensionality of the latent variable z.
  - n_latent_u (int (default: 10)) – Dimensionality of the latent variable u.
  - encoder_n_hidden (int (default: 128)) – Number of hidden units in the encoder.
  - encoder_n_layers (int (default: 2)) – Number of layers in the encoder.
  - z_u_prior (bool (default: True)) – Whether to place a Gaussian prior on z given u.
  - z_u_prior_scale (float (default: 0.0)) – Natural log of the scale parameter of the Gaussian prior placed on z given u. Only applies if learn_z_u_prior_scale is False.
  - u_prior_scale (float (default: 0.0)) – Natural log of the scale parameter of the Gaussian prior placed on u. If u_prior_mixture is True, this scale applies to each mixture component distribution.
  - u_prior_mixture (bool (default: True)) – Whether to use a mixture-of-Gaussians prior for u.
  - u_prior_mixture_k (int (default: 20)) – Number of mixture components in the mixture-of-Gaussians prior on u.
  - learn_z_u_prior_scale (bool (default: False)) – Whether to learn the scale parameter of the prior distribution of z given u.
  - scale_observations (bool (default: False)) – Whether to scale the loss associated with each observation by the total number of observations linked to the associated sample.
  - px_kwargs (dict | None (default: None)) – Keyword arguments for the generative model.
  - qz_kwargs (dict | None (default: None)) – Keyword arguments for the inference model from u to z.
  - qu_kwargs (dict | None (default: None)) – Keyword arguments for the inference model from x to u.
  - training (bool (default: True)) – Whether the model is in training mode.
  - n_obs_per_sample (jax.typing.ArrayLike | None (default: None)) – Number of observations per sample.
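Both z_u_prior_scale and u_prior_scale are given on the natural-log scale, so the default 0.0 corresponds to a scale of exp(0.0) = 1.0. The sketch below, written in NumPy rather than JAX so it runs standalone, shows how such a log-scale enters a Gaussian log density and how a K-component mixture prior on u can share a single component scale. The helper names gaussian_log_prob and mixture_log_prob, the component means, and the uniform mixture weights are hypothetical illustrations, not JaxMRVAE internals.

```python
import numpy as np


def gaussian_log_prob(x, mean, log_scale):
    """Elementwise log N(x | mean, exp(log_scale)**2)."""
    scale = np.exp(log_scale)  # log_scale = 0.0 -> scale = 1.0
    return -0.5 * ((x - mean) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def mixture_log_prob(u, means, log_weights, log_scale):
    """log p(u) under a K-component isotropic Gaussian mixture (hypothetical helper).

    u: (n_latent_u,), means: (K, n_latent_u), log_weights: (K,), normalized.
    All components share the same scale exp(log_scale).
    """
    # Per-component log density, summed over latent dimensions -> shape (K,)
    comp = gaussian_log_prob(u[None, :], means, log_scale).sum(axis=-1)
    joint = log_weights + comp
    m = joint.max()
    return m + np.log(np.exp(joint - m).sum())  # numerically stable logsumexp


# Hypothetical settings mirroring the defaults n_latent_u=10, u_prior_mixture_k=20
n_latent_u, k = 10, 20
u_prior_scale = 0.0
rng = np.random.default_rng(0)
means = rng.normal(size=(k, n_latent_u))  # hypothetical component means
log_weights = np.full(k, -np.log(k))      # hypothetical uniform weights
u = np.zeros(n_latent_u)

print(np.exp(u_prior_scale))  # component scale = 1.0
print(mixture_log_prob(u, means, log_weights, u_prior_scale))
```

Sharing one scale across components keeps the prior's spread controlled by a single hyperparameter, matching the description of u_prior_scale above; how the actual module parameterizes means and weights is not specified on this page.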
Attributes table
- Returns a tuple of rng sequence names required for this Flax module.
Methods table
- compute_h_from_x_eps – Compute normalized gene expression from observations using predefined eps.
- Generative model.
- inference – Latent variable inference.
- loss – Compute the loss function value.
- setup – Flax setup method.
Attributes
- JaxMRVAE.encoder_n_layers: int = 2
- JaxMRVAE.learn_z_u_prior_scale: bool = False
- JaxMRVAE.n_latent: int = 30
- JaxMRVAE.n_latent_u: int = 10
- JaxMRVAE.n_obs_per_sample: jax.typing.ArrayLike | None = None
- JaxMRVAE.px_kwargs: dict | None = None
- JaxMRVAE.qu_kwargs: dict | None = None
- JaxMRVAE.qz_kwargs: dict | None = None
- JaxMRVAE.scale_observations: bool = False
- JaxMRVAE.scope: Scope | None = None
- JaxMRVAE.training: bool = True
- JaxMRVAE.u_prior_mixture: bool = True
- JaxMRVAE.u_prior_mixture_k: int = 20
- JaxMRVAE.u_prior_scale: float = 0.0
- JaxMRVAE.z_u_prior: bool = True
- JaxMRVAE.z_u_prior_scale: float = 0.0
- JaxMRVAE.n_input: int
- JaxMRVAE.n_sample: int
- JaxMRVAE.n_batch: int
- JaxMRVAE.n_labels: int
Methods
- JaxMRVAE.compute_h_from_x_eps(x, sample_index, batch_index, extra_eps, cf_sample=None, mc_samples=10)
Compute normalized gene expression from observations using the predefined eps passed as extra_eps.
- JaxMRVAE.inference(x, sample_index, mc_samples=None, cf_sample=None, use_mean=False)
Latent variable inference.
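The mc_samples and use_mean arguments of inference suggest the usual choice between a point estimate and Monte Carlo draws of the latent variables. The following is a minimal, generic sketch of that pattern using the reparameterization trick, in NumPy. sample_latent is a hypothetical helper, and the exact shapes and behavior of JaxMRVAE.inference may differ.

```python
import numpy as np


def sample_latent(mean, scale, mc_samples=None, use_mean=False, rng=None):
    """Hypothetical helper illustrating mc_samples / use_mean semantics.

    mean, scale: arrays of shape (n_obs, n_latent).
    Returns mean itself if use_mean is True, one draw of shape
    (n_obs, n_latent) if mc_samples is None, and otherwise draws of
    shape (mc_samples, n_obs, n_latent).
    """
    if use_mean:
        return mean
    rng = np.random.default_rng() if rng is None else rng
    shape = mean.shape if mc_samples is None else (mc_samples, *mean.shape)
    eps = rng.standard_normal(shape)  # noise that does not depend on mean or scale
    return mean + scale * eps         # broadcasts over the leading sample axis


mean = np.zeros((4, 10))  # 4 observations, latent dimension 10
scale = np.ones((4, 10))
print(sample_latent(mean, scale, use_mean=True).shape)   # (4, 10)
print(sample_latent(mean, scale).shape)                 # (4, 10)
print(sample_latent(mean, scale, mc_samples=10).shape)  # (10, 4, 10)
```

Drawing the noise separately from mean and scale is what keeps the sample differentiable with respect to the encoder outputs in an autodiff framework such as JAX.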
- JaxMRVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)
Compute the loss function value.
- Return type:
LossOutput
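A rough NumPy sketch of how a loss with a kl_weight argument and the scale_observations option could combine per-observation terms. It assumes, hypothetically, a negative log-likelihood reconstruction term plus kl_weight times the KL terms for u and z, and it reads "scale by the total number of observations linked to the associated sample" as dividing by that count. Neither assumption is confirmed by this page, and mrvi_style_loss is a hypothetical name.

```python
import numpy as np


def mrvi_style_loss(recon_nll, kl_u, kl_z, kl_weight=1.0,
                    sample_index=None, n_obs_per_sample=None):
    """Hypothetical batch loss: mean over observations of recon + kl_weight * KL.

    If n_obs_per_sample is given, each observation's loss is divided by the
    number of observations in its sample (assumed reading of
    scale_observations=True).
    """
    per_obs = recon_nll + kl_weight * (kl_u + kl_z)
    if n_obs_per_sample is not None:
        per_obs = per_obs / n_obs_per_sample[sample_index]
    return per_obs.mean()


sample_index = np.array([0, 0, 0, 1])         # three observations in sample 0, one in sample 1
n_obs_per_sample = np.bincount(sample_index)  # counts per sample: [3, 1]
recon = np.array([2.0, 2.0, 2.0, 1.0])
kl_u = np.ones(4)
kl_z = np.ones(4)

print(mrvi_style_loss(recon, kl_u, kl_z, kl_weight=0.5))                                  # 2.75
print(mrvi_style_loss(recon, kl_u, kl_z, 0.5, sample_index, n_obs_per_sample))  # 1.25
```

Under this reading, the three observations of sample 0 together carry the same total weight as the single observation of sample 1, so samples with many observations do not dominate the objective.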
- JaxMRVAE.setup()
Flax setup method.
With scvi-tools, we prefer the setup parameterization of flax.linen Modules, which makes the interface more like PyTorch. More about this can be found here:
https://flax.readthedocs.io/en/latest/design_notes/setup_or_nncompact.html