scvi.module.JaxVAE
- class scvi.module.JaxVAE(n_input, n_batch, n_hidden=128, n_latent=30, dropout_rate=0.0, n_layers=1, gene_likelihood='nb', eps=1e-08, training=True, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: JaxBaseModuleClass
Variational autoencoder model.
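- Parameters:
n_input (int) – Number of input features (genes).
n_batch (int) – Number of batches.
n_hidden (int) – Number of hidden units per layer.
n_latent (int) – Dimensionality of the latent space.
dropout_rate (float) – Dropout rate.
n_layers (int) – Number of hidden layers.
gene_likelihood (str) – Likelihood used for the counts.
eps (float) – Small constant for numerical stability.
training (bool) – Whether the module runs in training mode.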
Attributes table

| Attribute | Description |
|---|---|
| required_rngs | Returns a tuple of rng sequence names required for this Flax module. |
Methods table

| Method | Description |
|---|---|
| generative | Run generative model. |
| inference | Run inference model. |
| loss | Compute loss. |
| setup | Set up model. |
Attributes
dropout_rate
eps
gene_likelihood
n_hidden
n_latent
n_layers
name
parent
required_rngs
scope
- JaxVAE.scope: Optional[Scope] = None
training
n_input
n_batch
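`required_rngs` names the random streams a Flax module draws from; each name needs its own key in the `rngs` dictionary passed to `init` or `apply`. The sketch below assumes the names are `"params"`, `"dropout"` and `"z"` (check `JaxVAE.required_rngs` for the actual tuple) and uses NumPy generators in place of JAX keys. In JAX the same mapping would be built with `dict(zip(names, jax.random.split(key, len(names))))`. The helper name is hypothetical.

```python
import numpy as np

# Assumed stream names (hypothetical); the real tuple comes from JaxVAE.required_rngs.
REQUIRED_RNGS = ("params", "dropout", "z")


def make_named_generators(seed, names=REQUIRED_RNGS):
    """Derive one independent generator per named random stream from a single seed.

    Mirrors a Flax ``rngs`` dict, where every stream name maps to its own key.
    """
    children = np.random.SeedSequence(seed).spawn(len(names))
    return {name: np.random.default_rng(child) for name, child in zip(names, children)}


rngs = make_named_generators(0)
print(sorted(rngs))  # ['dropout', 'params', 'z']
```

Keeping the streams separate means, for example, that dropout masks and latent samples never consume the same random numbers, and each stream stays reproducible from the one seed.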
Methods
generative
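A minimal NumPy sketch of what a generative step of this kind computes, assuming a decoder whose output is normalized with a softmax over genes and rescaled by each cell's observed total count (library size), with the batch supplied as a one-hot vector of length `n_batch`. The single linear layer `W`, `b` is a hypothetical stand-in for the real decoder network and the function names are illustrative; this is not the library's implementation.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # shift by the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def generative_sketch(x, z, batch_index, n_batch, W, b):
    """Map latent codes z (plus a one-hot batch label) to per-gene expected counts.

    W and b form a hypothetical one-layer decoder standing in for the real network.
    """
    batch = np.eye(n_batch)[np.asarray(batch_index).reshape(-1)]  # (cells, n_batch)
    h = np.concatenate([z, batch], axis=-1)  # condition the decoder on batch
    rho = softmax(h @ W + b, axis=-1)  # gene proportions; each row sums to 1
    library = x.sum(axis=-1, keepdims=True)  # observed total counts per cell
    mu = library * rho  # expected counts per gene
    return {"rho": rho, "mu": mu}


rng = np.random.default_rng(0)
x = rng.poisson(5.0, size=(3, 4)).astype(float)  # 3 cells, 4 genes
z = rng.standard_normal((3, 2))  # 2 latent dimensions
W = rng.standard_normal((2 + 2, 4))  # (n_latent + n_batch, n_genes)
out = generative_sketch(x, z, batch_index=[0, 1, 1], n_batch=2, W=W, b=np.zeros(4))
```

Normalizing with a softmax and multiplying by the library size makes the decoder model relative expression, so each cell's expected counts sum to its observed total.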
inference
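The inference step maps each cell to an approximate posterior over the latent code and draws a sample from it. A minimal sketch of the reparameterized draw, assuming the encoder has already produced a mean and a variance per latent dimension and that `eps` (the constructor argument) is added to the standard deviation to keep it strictly positive; the function name and returned keys are hypothetical.

```python
import numpy as np


def inference_sketch(mean, var, rng, eps=1e-8, n_samples=1):
    """Draw z ~ Normal(mean, sqrt(var) + eps) with the reparameterization trick."""
    stddev = np.sqrt(var) + eps  # eps keeps the scale strictly positive
    shape = mean.shape if n_samples == 1 else (n_samples, *mean.shape)
    noise = rng.standard_normal(shape)  # all randomness is isolated in the noise
    z = mean + stddev * noise  # differentiable in mean and stddev
    return {"qz_loc": mean, "qz_scale": stddev, "z": z}


rng = np.random.default_rng(0)
mean = np.zeros((4, 30))  # 4 cells, n_latent=30 (the default)
var = np.ones((4, 30))
out = inference_sketch(mean, var, rng)
print(out["z"].shape)  # (4, 30)
```

Writing the sample as `mean + stddev * noise` moves the randomness into `noise`, so gradients with respect to the encoder outputs can flow through `z`; in a JAX module the noise would come from a dedicated RNG stream.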
loss
- JaxVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)
Compute loss.
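- Parameters:
tensors – Minibatch of input tensors.
inference_outputs – Outputs of inference.
generative_outputs – Outputs of generative.
kl_weight (float) – Weight applied to the KL divergence term of the loss.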
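A minimal NumPy sketch of a loss with this signature's shape, covering only the default `gene_likelihood='nb'` case and assuming a standard normal prior on the latent code (the usual VAE choice): the per-cell negative log-likelihood of the counts plus `kl_weight` times the KL divergence of the approximate posterior from the prior, averaged over cells. Function names are hypothetical, `theta` is the negative binomial inverse dispersion, and plain arrays replace the `tensors`, `inference_outputs` and `generative_outputs` arguments.

```python
import math

import numpy as np

_lgamma = np.vectorize(math.lgamma, otypes=[float])  # elementwise log-gamma


def nb_log_prob(x, mu, theta, eps=1e-8):
    """Log-pmf of a negative binomial with mean mu and inverse dispersion theta."""
    log_theta_mu = np.log(theta + mu + eps)
    return (
        _lgamma(x + theta) - _lgamma(theta) - _lgamma(x + 1.0)
        + theta * (np.log(theta + eps) - log_theta_mu)
        + x * (np.log(mu + eps) - log_theta_mu)
    )


def kl_normal_std_normal(loc, scale):
    """Elementwise KL(Normal(loc, scale) || Normal(0, 1))."""
    return 0.5 * (scale**2 + loc**2 - 1.0) - np.log(scale)


def loss_sketch(x, mu, theta, qz_loc, qz_scale, kl_weight=1.0):
    reconst = -nb_log_prob(x, mu, theta).sum(-1)  # per-cell reconstruction loss
    kl = kl_normal_std_normal(qz_loc, qz_scale).sum(-1)  # per-cell KL over latent dims
    return np.mean(reconst + kl_weight * kl)  # average over cells


rng = np.random.default_rng(0)
x = rng.poisson(3.0, size=(5, 4)).astype(float)  # 5 cells, 4 genes
mu = np.full((5, 4), 3.0)
loc, scale = np.ones((5, 2)), np.ones((5, 2))
full = loss_sketch(x, mu, 2.0, loc, scale, kl_weight=1.0)
no_kl = loss_sketch(x, mu, 2.0, loc, scale, kl_weight=0.0)
print(round(full - no_kl, 6))  # 1.0: each cell's KL is 2 * 0.5 with loc=1, scale=1
```

Scaling only the KL term by `kl_weight` is what makes KL warm-up possible: training can start with a small weight and raise it toward 1.0.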
setup