scvi.module.JaxVAE
- class scvi.module.JaxVAE(n_input, n_batch, n_hidden=128, n_latent=30, dropout_rate=0.0, n_layers=1, gene_likelihood='nb', eps=1e-08, training=True, n_continuous_cov=0, n_cats_per_cov=(), parent=<flax.linen.module._Sentinel object>, name=None)
Bases: JaxBaseModuleClass

Variational autoencoder model.
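The `gene_likelihood='nb'` default refers to a negative binomial observation model for counts. As a rough illustration only, the sketch below computes a negative binomial log-probability in the mean / inverse-dispersion parameterization commonly used in scVI. The function name `nb_log_likelihood` is hypothetical, and treating `eps` as a small shift inside the logarithms is an assumption about its role, not a description of `JaxVAE`'s internals.

```python
import math


def nb_log_likelihood(x: float, mu: float, theta: float, eps: float = 1e-8) -> float:
    """Hypothetical helper: log P(x) under NB(mean=mu, inverse dispersion=theta).

    Assumption: eps only shifts the log arguments away from zero for stability.
    """
    # Shared term log(theta + mu), shifted by eps to avoid log(0)
    log_theta_mu_eps = math.log(theta + mu + eps)
    return (
        theta * (math.log(theta + eps) - log_theta_mu_eps)  # theta * log(theta / (theta + mu))
        + x * (math.log(mu + eps) - log_theta_mu_eps)  # x * log(mu / (theta + mu))
        + math.lgamma(x + theta)  # combinatorial normalizer terms
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )
```

Working in log space with `math.lgamma` keeps the computation stable for large counts, where the raw gamma functions would overflow.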
Attributes table

- Returns a tuple of rng sequence names required for this Flax module.

Methods table

- Run generative model.
- Run inference model.
- Compute loss.
- Setup model.
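As a generic illustration of the inference and loss steps a variational autoencoder typically performs, the sketch below draws a latent sample with the reparameterization trick and computes a negative evidence lower bound (ELBO). It assumes a diagonal Gaussian posterior described by a mean and a log-variance, and a standard normal prior. The function names are hypothetical, and this is not `JaxVAE`'s actual inference or loss code.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mean: Sequence[float], log_var: Sequence[float], rng: random.Random) -> List[float]:
    """Hypothetical helper: z = mean + sigma * e with e ~ N(0, 1) per dimension."""
    # All randomness lives in e, so z stays a differentiable function of mean and log_var
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, log_var)]


def kl_to_standard_normal(mean: Sequence[float], log_var: Sequence[float]) -> float:
    """Closed-form KL(N(mean, diag(exp(log_var))) || N(0, I)), summed over latent dims."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mean, log_var))


def negative_elbo(reconstruction_log_lik: float, mean: Sequence[float], log_var: Sequence[float]) -> float:
    """Hypothetical loss: minus reconstruction log-likelihood plus the KL penalty."""
    return -reconstruction_log_lik + kl_to_standard_normal(mean, log_var)
```

In a model like this, `reconstruction_log_lik` would come from the observation likelihood evaluated on the generative output, so minimizing the negative ELBO trades reconstruction quality against how far the posterior drifts from the prior.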
Attributes
- JaxVAE.scope: Scope | None = None
Methods
- JaxVAE.generative(x, z, batch_index, cont_covs=None, cat_covs=None)
Run generative model.
- Return type: