scvi.module.JaxVAE

class scvi.module.JaxVAE(n_input, n_batch, n_hidden=128, n_latent=30, dropout_rate=0.0, n_layers=1, gene_likelihood='nb', eps=1e-08, training=True, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: JaxBaseModuleClass

Variational autoencoder model.

Attributes table

dropout_rate

eps

gene_likelihood

n_hidden

n_latent

n_layers

name

parent

required_rngs

Returns a tuple of rng sequence names required for this Flax module.

scope

training

n_input

n_batch
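required_rngs names the random-number streams the module expects. Below is a minimal sketch of supplying one independent JAX PRNG key per stream. It assumes the names are ("params", "dropout", "z"); that tuple is hardcoded here for illustration and should be read from JaxVAE.required_rngs in practice. make_rngs is a hypothetical helper, not part of scvi-tools.

```python
import jax

# Assumed stream names; in practice read them from the module's required_rngs.
ASSUMED_RNG_NAMES = ("params", "dropout", "z")


def make_rngs(seed: int, names=ASSUMED_RNG_NAMES) -> dict:
    """Hypothetical helper: split one root key into one key per named RNG stream."""
    root = jax.random.PRNGKey(seed)
    keys = jax.random.split(root, len(names))
    return {name: key for name, key in zip(names, keys)}


rngs = make_rngs(0)
print(sorted(rngs))  # ['dropout', 'params', 'z']
```

Flax's Module.init and Module.apply accept a dict of this shape for their rngs argument, and a module draws from a named stream with make_rng.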

Methods table

generative(x, z, batch_index)

Run the generative model.

inference(x[, n_samples])

Run the inference model.

loss(tensors, inference_outputs, ...[, ...])

Compute the loss.

setup()

Set up the model.

Attributes

JaxVAE.dropout_rate: Tunable_[float] = 0.0
JaxVAE.eps: Tunable_[float] = 1e-08
JaxVAE.gene_likelihood: Tunable_[str] = 'nb'
JaxVAE.n_hidden: Tunable_[int] = 128
JaxVAE.n_latent: Tunable_[int] = 30
JaxVAE.n_layers: Tunable_[int] = 1
JaxVAE.name: Optional[str] = None
JaxVAE.parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None
JaxVAE.required_rngs
JaxVAE.scope: Optional[Scope] = None
JaxVAE.training: bool = True
JaxVAE.n_input: int
JaxVAE.n_batch: int
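gene_likelihood defaults to 'nb' (negative binomial). The plain-Python sketch below writes out a negative binomial log-probability, parameterized by a mean mu and an inverse dispersion theta, next to a Poisson one. It shows why the choice matters: as theta grows the negative binomial approaches the Poisson, while a small theta allows overdispersion, for example extra mass at zero. Treating 'poisson' as the other accepted value is an assumption here, and the function names are illustrative.

```python
import math


def log_nb(x: int, mu: float, theta: float) -> float:
    """Negative binomial log-pmf with mean mu > 0 and inverse dispersion theta > 0."""
    log_theta_mu = math.log(theta + mu)
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mu) - log_theta_mu)
    )


def log_poisson(x: int, mu: float) -> float:
    """Poisson log-pmf with mean mu > 0."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)


# Large theta: the negative binomial is nearly Poisson.
print(log_nb(3, 2.0, 1e6), log_poisson(3, 2.0))
# Small theta: zeros become much more likely than under a Poisson with the same mean.
print(log_nb(0, 2.0, 0.5), log_poisson(0, 2.0))
```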

Methods

JaxVAE.generative(x, z, batch_index)

Run the generative model on latent samples z, conditioned on batch_index.

Return type:

dict
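The NumPy sketch below shows the kind of computation a scVI-style generative step performs. It rests on these assumptions: batch_index is one-hot encoded, the decoder's unnormalized output passes through a softmax over genes to give expression proportions rho, and the likelihood mean is rho scaled by each cell's observed total count. The decoder network itself is omitted (its output is taken as an input), and the helper names are hypothetical. This is not the JaxVAE source.

```python
import numpy as np


def one_hot(batch_index: np.ndarray, n_batch: int) -> np.ndarray:
    """Encode integer batch labels of shape (n_cells, 1) as (n_cells, n_batch)."""
    return np.eye(n_batch)[batch_index.reshape(-1)]


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last (gene) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def likelihood_mean(x: np.ndarray, rho_unnorm: np.ndarray):
    """Hypothetical helper: per-gene means mu = library_size * rho."""
    rho = softmax(rho_unnorm)                # proportions, each row sums to 1
    library = x.sum(axis=-1, keepdims=True)  # total counts per cell, as a column
    return library * rho, rho


x = np.array([[1.0, 2.0, 3.0], [0.0, 4.0, 0.0]])           # 2 cells, 3 genes
rho_unnorm = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # stand-in decoder output
mu, rho = likelihood_mean(x, rho_unnorm)
batch = one_hot(np.array([[0], [1]]), n_batch=2)           # decoder-side batch input
print(mu)
```

Because rho sums to 1 per cell, the means mu always sum to that cell's observed library size. In a real decoder the one-hot batch vector would enter the network alongside z.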

JaxVAE.inference(x, n_samples=1)

Run the inference model on x, drawing n_samples latent samples.

Return type:

dict
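The NumPy sketch below shows the reparameterization trick a Gaussian-posterior VAE uses to draw latent samples. It assumes the encoder returns a mean and a variance per latent dimension, and that a small eps is added to the standard deviation for numerical stability. That is one plausible use of the eps attribute, not confirmed here. With n_samples=1 the sample has the shape of the mean; otherwise a leading sample axis is added. sample_latent is a hypothetical helper.

```python
import numpy as np


def sample_latent(mean, var, n_samples=1, eps=1e-8, rng=None):
    """Hypothetical helper: z = mean + std * noise with noise ~ N(0, I)."""
    rng = np.random.default_rng(0) if rng is None else rng
    std = np.sqrt(var) + eps
    # Add a leading sample axis only when more than one sample is requested.
    shape = mean.shape if n_samples == 1 else (n_samples, *mean.shape)
    noise = rng.standard_normal(shape)
    return mean + std * noise  # mean and std broadcast over the sample axis


mean = np.zeros((4, 30))  # 4 cells, n_latent = 30
var = np.ones((4, 30))
z = sample_latent(mean, var)
z_multi = sample_latent(mean, var, n_samples=5)
print(z.shape, z_multi.shape)  # (4, 30) (5, 4, 30)
```

The noise is drawn independently of mean and var, so in an autodiff framework gradients can flow through both to the encoder.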

JaxVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)

Compute the loss from the input tensors and the outputs of inference and generative.
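The NumPy sketch below shows a negative ELBO of the form such a model typically minimizes. It assumes a standard normal prior on z, a diagonal Gaussian posterior, and that kl_weight multiplies the KL term (for example during KL warm-up). The reconstruction loss is the negative log-likelihood summed over genes, the KL divergence has a closed form, and the result is averaged over cells. The helpers are hypothetical, and the reconstruction log-probabilities are taken as given rather than computed from a likelihood.

```python
import numpy as np


def kl_to_standard_normal(mean, var):
    """Closed-form KL(N(mean, var) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(var + mean**2 - 1.0 - np.log(var), axis=-1)


def negative_elbo(recon_log_prob, mean, var, kl_weight=1.0):
    """Hypothetical helper: mean over cells of -log p(x | z) + kl_weight * KL."""
    reconst_loss = -recon_log_prob.sum(axis=-1)  # sum over genes
    kl_local = kl_to_standard_normal(mean, var)
    return np.mean(reconst_loss + kl_weight * kl_local)


log_px = np.full((2, 3), -1.0)                # 2 cells, 3 genes
mean, var = np.ones((2, 4)), np.ones((2, 4))  # 4 latent dimensions
print(negative_elbo(log_px, mean, var))                 # 5.0 (reconstruction 3.0 + KL 2.0)
print(negative_elbo(log_px, mean, var, kl_weight=0.0))  # 3.0
```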

JaxVAE.setup()

Set up the model.