scvi.module.base.JaxBaseModuleClass
- class scvi.module.base.JaxBaseModuleClass(parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Abstract class for Jax-based scvi-tools modules.
The JaxBaseModuleClass provides an interface for Jax-backed modules that is consistent with the BaseModuleClass. Any subclass must have a training parameter in its constructor and must use the @flax_configure decorator.
Children of JaxBaseModuleClass should use the instance attribute self.training to modify the model's behavior depending on whether it is in training or evaluation mode.
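Attributes table
| Attribute | Description |
|---|---|
| required_rngs | Returns a tuple of rng sequence names required for this Flax module. |
| rngs | Dictionary of RNGs mapping required RNG name to RNG values. |
| scope | |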
Methods table
| Method | Description |
|---|---|
| as_bound | Module bound with parameters learned from training. |
| as_numpy_array | Converts a jax device array to a numpy array. |
| configure | Add necessary attrs. |
| eval | Switch to evaluation mode. |
| generative | Run the generative model. |
| get_jit_inference_fn | Create a method to run inference using the bound module. |
| inference | Run the recognition model. |
| load_state_dict | Load a state dictionary into a train state. |
| loss | Compute the loss for a minibatch of data. |
| on_load | Callback function run in load(). |
| setup | Flax setup method. |
| state_dict | Returns a serialized version of the train state as a dictionary. |
| to | Move module to device. |
| train | Switch to train mode. |
Attributes
- JaxBaseModuleClass.required_rngs
Returns a tuple of rng sequence names required for this Flax module.
- JaxBaseModuleClass.rngs
Dictionary of RNGs mapping required RNG name to RNG values.
Calls self._split_rngs(), so every reference to self.rngs generates new RNGs.
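A minimal sketch of this behavior, assuming JAX's explicit PRNG keys: the hypothetical RngHolder stores one key and splits it on every access to rngs, so each access hands out fresh keys. The RNG names ("params", "dropout", "z") are assumptions for illustration, and this is not the scvi-tools implementation of _split_rngs.

```python
import jax


class RngHolder:
    """Hypothetical stand-in for the RNG bookkeeping described above."""

    required_rngs = ("params", "dropout", "z")

    def __init__(self, seed: int = 0):
        self.seed_rng = jax.random.PRNGKey(seed)

    def _split_rngs(self):
        # Split the stored key: keep the first subkey as new state, hand out the rest
        keys = jax.random.split(self.seed_rng, len(self.required_rngs) + 1)
        self.seed_rng = keys[0]
        return {name: key for name, key in zip(self.required_rngs, keys[1:])}

    @property
    def rngs(self):
        # Every access advances the stored key, so the returned keys are always new
        return self._split_rngs()


holder = RngHolder(seed=0)
first = holder.rngs
second = holder.rngs
```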
- JaxBaseModuleClass.scope: Scope | None = None
Methods
- JaxBaseModuleClass.as_bound()
Module bound with parameters learned from training.
- abstract JaxBaseModuleClass.generative(*args, **kwargs)
Run the generative model.
This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x \mid z)\).
This function should return a dictionary with str keys and ndarray values.
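As an illustration, the hypothetical function below plays the role of generative for a diagonal Gaussian \(p(x \mid z)\) with a linear decoder. The key names px_mean and px_scale and the fixed unit scale are assumptions, not scvi-tools conventions.

```python
import jax
import jax.numpy as jnp


def generative(z, weight, bias):
    """Hypothetical linear decoder for p(x | z) as a diagonal Gaussian."""
    px_mean = z @ weight + bias
    # A fixed unit scale keeps the sketch small; real decoders often learn it
    px_scale = jnp.ones_like(px_mean)
    # Likelihood parameters returned as a dict with str keys and array values
    return {"px_mean": px_mean, "px_scale": px_scale}


z = jnp.zeros((5, 2))
weight = jax.random.normal(jax.random.PRNGKey(0), (2, 3))
bias = jnp.arange(3.0)
out = generative(z, weight, bias)
```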
- JaxBaseModuleClass.get_jit_inference_fn(get_inference_input_kwargs=None, inference_kwargs=None)
Create a method to run inference using the bound module.
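The general idea of a jitted inference method can be sketched with jax.jit and functools.partial: trained parameters are closed over, standing in for a bound module, and fixed keyword arguments are baked in before compilation. The names inference, make_jit_inference_fn and n_samples are hypothetical, and only the inference_kwargs argument is mirrored here.

```python
import functools

import jax
import jax.numpy as jnp


def inference(params, x, n_samples=1):
    """Hypothetical inference step: a linear encoder producing a latent mean."""
    qz_mean = x @ params["w"]
    # Stack n_samples copies along a new leading axis
    return {"qz_mean": jnp.repeat(qz_mean[None], n_samples, axis=0)}


def make_jit_inference_fn(params, inference_kwargs=None):
    # Fix keyword arguments up front so they are static inside the compiled function
    inference_kwargs = {} if inference_kwargs is None else inference_kwargs
    fn = functools.partial(inference, **inference_kwargs)

    @jax.jit
    def _run(x):
        # params is closed over, mimicking a module already bound to trained weights
        return fn(params, x)

    return _run


params = {"w": jnp.ones((3, 2))}
run = make_jit_inference_fn(params, inference_kwargs={"n_samples": 4})
out = run(jnp.ones((5, 3)))
```

Closed-over arrays are treated as constants by jax.jit, which fits a module whose parameters are fixed after training.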
- abstract JaxBaseModuleClass.inference(*args, **kwargs)
Run the recognition model.
In the case of variational inference, this function performs the steps related to computing the variational distribution parameters. In a VAE, this involves running data through encoder networks.
This function should return a dictionary with str keys and ndarray values.
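For a VAE, the recognition step can be sketched as a Gaussian encoder \(q(z \mid x)\) with the reparameterization trick. The linear encoder and the key names qz_mean, qz_logvar and z are hypothetical.

```python
import jax
import jax.numpy as jnp


def inference(x, params, key):
    """Hypothetical Gaussian encoder q(z | x) with a reparameterized sample."""
    qz_mean = x @ params["w_mean"]
    qz_logvar = x @ params["w_logvar"]
    # Reparameterization trick: z = mean + std * eps keeps z differentiable w.r.t. params
    eps = jax.random.normal(key, qz_mean.shape)
    z = qz_mean + jnp.exp(0.5 * qz_logvar) * eps
    return {"qz_mean": qz_mean, "qz_logvar": qz_logvar, "z": z}


x = jnp.ones((5, 3))
params = {"w_mean": jnp.zeros((3, 2)), "w_logvar": jnp.zeros((3, 2))}
out = inference(x, params, jax.random.PRNGKey(0))
```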
- abstract JaxBaseModuleClass.loss(*args, **kwargs)
Compute the loss for a minibatch of data.
This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.
This function should return an object of type LossOutput.
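A minimal sketch of such a loss, assuming a Gaussian likelihood and a standard normal prior, computes the negative ELBO from the inference and generative outputs. It returns a plain dict as a stand-in for LossOutput; the key names are illustrative only.

```python
import jax.numpy as jnp


def loss(x, inference_outputs, generative_outputs, kl_weight=1.0):
    """Hypothetical negative ELBO for a Gaussian VAE (not scvi-tools' LossOutput)."""
    mean, logvar = inference_outputs["qz_mean"], inference_outputs["qz_logvar"]
    px_mean, px_scale = generative_outputs["px_mean"], generative_outputs["px_scale"]
    # Per-cell Gaussian negative log-likelihood of x under p(x | z)
    nll = jnp.sum(
        0.5 * ((x - px_mean) / px_scale) ** 2 + jnp.log(px_scale) + 0.5 * jnp.log(2 * jnp.pi),
        axis=-1,
    )
    # Closed-form KL(q(z | x) || N(0, I)) per cell
    kl = -0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=-1)
    total = jnp.mean(nll + kl_weight * kl)
    return {"loss": total, "reconstruction_loss": nll, "kl_local": kl}


x = jnp.zeros((4, 3))
inference_outputs = {"qz_mean": jnp.zeros((4, 2)), "qz_logvar": jnp.zeros((4, 2))}
generative_outputs = {"px_mean": jnp.zeros((4, 3)), "px_scale": jnp.ones((4, 3))}
out = loss(x, inference_outputs, generative_outputs)
```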
- static JaxBaseModuleClass.on_load(model)
Callback function run in load().
Runs one training step before loading the state dict in order to initialize params.
- abstract JaxBaseModuleClass.setup()
Flax setup method.
With scvi-tools, we prefer the setup parameterization of flax.linen Modules, which makes the interface more like PyTorch. More about this can be found here:
https://flax.readthedocs.io/en/latest/design_notes/setup_or_nncompact.html