scvi.module.JaxVAE

class scvi.module.JaxVAE(n_input, n_batch, n_hidden=128, n_latent=30, dropout_rate=0.0, n_layers=1, gene_likelihood='nb', eps=1e-08, training=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Bases: JaxBaseModuleClass

Variational autoencoder (VAE) model, implemented as a Flax module.

Attributes table

dropout_rate

eps

gene_likelihood

n_hidden

n_latent

n_layers

name

parent

required_rngs

Return a tuple of the RNG sequence names required by this Flax module.

scope

training

n_input

n_batch

Methods table

generative(x, z, batch_index)

Run the generative model.

inference(x[, n_samples])

Run the inference model.

loss(tensors, inference_outputs, ...[, ...])

Compute the loss.

setup()

Set up the model.

Attributes

dropout_rate

JaxVAE.dropout_rate: float = 0.0

eps

JaxVAE.eps: float = 1e-08

gene_likelihood

JaxVAE.gene_likelihood: str = 'nb'

n_hidden

JaxVAE.n_hidden: int = 128

n_latent

JaxVAE.n_latent: int = 30

n_layers

JaxVAE.n_layers: int = 1

name

JaxVAE.name: str = None

parent

JaxVAE.parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>

required_rngs

JaxVAE.required_rngs[source]

Return a tuple of the RNG sequence names required by this Flax module.

scope

JaxVAE.scope = None

training

JaxVAE.training: bool = True

n_input

JaxVAE.n_input: int

n_batch

JaxVAE.n_batch: int

Methods

generative

JaxVAE.generative(x, z, batch_index)[source]

Run the generative model.

Return type:

dict

inference

JaxVAE.inference(x, n_samples=1)[source]

Run the inference model.

Return type:

dict

loss

JaxVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)[source]

Compute the loss.

setup

JaxVAE.setup()[source]

Set up the model.