scvi.module.JaxVAE
- class scvi.module.JaxVAE(n_input, n_batch, n_hidden=128, n_latent=30, dropout_rate=0.0, n_layers=1, gene_likelihood='nb', eps=1e-08, training=True, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: JaxBaseModuleClass
Variational autoencoder model.
Attributes table
required_rngs | Returns a tuple of rng sequence names required for this Flax module.
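Flax modules draw randomness from named rng streams, and required_rngs reports which stream names this module expects. A minimal sketch in plain JAX of how a caller might build the matching dictionary, one independent key per stream, to hand to Flax's Module.init (or the rngs argument of Module.apply). The stream names ("params", "dropout", "z") and the helper make_rng_dict are assumptions for illustration; read the actual tuple from JaxVAE.required_rngs.

```python
import jax

# Hypothetical stream names; the real tuple comes from JaxVAE.required_rngs
# and may differ between scvi-tools versions.
RNG_NAMES = ("params", "dropout", "z")


def make_rng_dict(seed, rng_names):
    """Split one root PRNG key into one independent key per named stream."""
    root = jax.random.PRNGKey(seed)
    keys = jax.random.split(root, len(rng_names))
    return {name: key for name, key in zip(rng_names, keys)}


rngs = make_rng_dict(0, RNG_NAMES)
# A Flax module could then receive these, e.g. module.init(rngs, *inputs)
# or module.apply(variables, *inputs, rngs=rngs).
```

Splitting a single root key keeps the streams independent while the whole run stays reproducible from one seed.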
Methods table
generative | Run generative model.
inference | Run inference model.
loss | Compute loss.
setup | Set up model.
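The methods table only names the two halves of the model. As a rough sketch, the inference and generative steps of a negative-binomial VAE like this one could look as follows. This assumes a single linear layer for the encoder and decoder, a log1p input transform, one-hot batch conditioning, and a softmax decoder scaled by library size; none of this is confirmed by this page. The functions, the params keys, and the toy sizes are hypothetical stand-ins, not JaxVAE's own methods.

```python
import jax
import jax.numpy as jnp


def inference(params, x, rng):
    """Map counts x to a diagonal Gaussian q(z | x) and sample z from it."""
    h = jnp.log1p(x)  # assumed log1p transform of raw counts
    mean = h @ params["w_mean"]
    std = jnp.exp(0.5 * (h @ params["w_log_var"]))
    noise = jax.random.normal(rng, mean.shape)
    z = mean + std * noise  # reparameterization: gradients flow through mean and std
    return {"mean": mean, "std": std, "z": z}


def generative(params, z, batch_index, n_batch, library_size):
    """Decode z (plus a one-hot batch label) into negative-binomial parameters."""
    z_b = jnp.concatenate([z, jax.nn.one_hot(batch_index, n_batch)], axis=-1)
    rho = jax.nn.softmax(z_b @ params["w_dec"], axis=-1)  # per-gene proportions
    mu = library_size * rho  # NB mean; each row sums to that cell's library size
    theta = jnp.exp(params["log_theta"])  # gene-wise inverse dispersion
    return {"mu": mu, "theta": theta}


# Toy dimensions and random single-layer weights (all hypothetical).
n_cells, n_genes, n_latent, n_batch = 4, 6, 2, 3
keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    "w_mean": 0.1 * jax.random.normal(keys[0], (n_genes, n_latent)),
    "w_log_var": 0.1 * jax.random.normal(keys[1], (n_genes, n_latent)),
    "w_dec": 0.1 * jax.random.normal(keys[2], (n_latent + n_batch, n_genes)),
    "log_theta": jnp.zeros(n_genes),
}
x = jax.random.poisson(keys[3], 3.0, (n_cells, n_genes)).astype(jnp.float32)
batch_index = jnp.array([0, 1, 2, 0])

inference_outputs = inference(params, x, keys[4])
generative_outputs = generative(
    params, inference_outputs["z"], batch_index, n_batch,
    library_size=x.sum(axis=-1, keepdims=True),
)
```

Inside a Flax module, the sampling key would typically come from a named rng stream via self.make_rng rather than being passed in explicitly.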
Attributes
dropout_rate
eps
gene_likelihood
n_hidden
n_latent
n_layers
name
parent
- JaxVAE.parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
required_rngs
scope
- JaxVAE.scope = None
training
n_input
n_batch
Methods#
generative
inference
loss
setup
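loss is documented only as "Compute loss." For a VAE with gene_likelihood='nb', a common objective is the negative evidence lower bound: a negative-binomial reconstruction term plus the KL divergence from a diagonal Gaussian posterior to a standard-normal prior. A minimal sketch in plain JAX under those assumptions follows. The names log_nb, kl_to_standard_normal, negative_elbo and kl_weight are hypothetical. The eps guard against log(0) is also an assumption and is not necessarily how JaxVAE uses its own eps parameter.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_nb(x, mu, theta, eps=1e-8):
    """Log-pmf of a negative binomial with mean mu and inverse dispersion theta."""
    log_theta_mu = jnp.log(theta + mu + eps)
    return (
        gammaln(x + theta)
        - gammaln(theta)
        - gammaln(x + 1.0)
        + theta * (jnp.log(theta + eps) - log_theta_mu)
        + x * (jnp.log(mu + eps) - log_theta_mu)
    )


def kl_to_standard_normal(mean, std):
    """Elementwise KL( N(mean, std^2) || N(0, 1) )."""
    return -jnp.log(std) + 0.5 * (std**2 + mean**2 - 1.0)


def negative_elbo(x, mu, theta, q_mean, q_std, kl_weight=1.0):
    """Mean over cells of the reconstruction loss plus the weighted KL term."""
    reconstruction = -log_nb(x, mu, theta).sum(axis=-1)  # sum over genes
    kl = kl_to_standard_normal(q_mean, q_std).sum(axis=-1)  # sum over latent dims
    return jnp.mean(reconstruction + kl_weight * kl)


# Sanity check: the pmf over a long range of counts should sum to about 1.
counts = jnp.arange(0, 200, dtype=jnp.float32)
total_mass = jnp.exp(log_nb(counts, jnp.array(2.0), jnp.array(3.0))).sum()
```

Minimizing this quantity trades off reconstruction of the observed counts against keeping the posterior close to the prior; kl_weight lets that balance be annealed during training.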