scvi.module.base.JaxBaseModuleClass#

class scvi.module.base.JaxBaseModuleClass(parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Abstract class for Jax-based scvi-tools modules.

The JaxBaseModuleClass provides an interface for Jax-backed modules consistent with the BaseModuleClass.

Any subclass must have a training parameter in its constructor and must use the @flax_configure decorator.

Children of JaxBaseModuleClass should use the instance attribute self.training to modify the model's behavior depending on whether it is in training or evaluation mode.
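For orientation, a minimal sketch of a subclass is shown below. The field names, layer sizes, and the "dropout" and "z" RNG names are illustrative assumptions, not requirements of this class; sketches of the abstract inference, generative, and loss methods for this hypothetical module appear under the corresponding method entries below.

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    from scvi.module.base import JaxBaseModuleClass, LossOutput, flax_configure


    @flax_configure
    class MyJaxModule(JaxBaseModuleClass):
        # Flax modules are dataclasses, so the "constructor" is the field
        # list; `training` is the required parameter mentioned above.
        n_input: int
        n_latent: int = 10
        training: bool = True

        def setup(self):
            # Submodules are assigned here, PyTorch-style (see `setup` below).
            self.encoder = nn.Dense(128)
            self.mean_head = nn.Dense(self.n_latent)
            self.logvar_head = nn.Dense(self.n_latent)
            self.decoder = nn.Dense(self.n_input)

        @property
        def required_rngs(self):
            # "params" is always required by Flax; "z" drives latent sampling.
            return ("params", "dropout", "z")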

Attributes table#

device

name

params

parent

required_rngs

Returns a tuple of rng sequence names required for this Flax module.

rngs

Dictionary of RNGs mapping required RNG names to RNG values.

scope

state

Methods table#

as_bound()

Module bound with parameters learned from training.

as_numpy_array(x)

Converts a JAX device array to a NumPy array.

configure()

Add necessary attrs.

eval()

Switch to evaluation mode.

generative(*args, **kwargs)

Run the generative model.

get_jit_inference_fn([...])

Create a method to run inference using the bound module.

inference(*args, **kwargs)

Run the recognition model.

load_state_dict(state_dict)

Load a state dictionary into a train state.

loss(*args, **kwargs)

Compute the loss for a minibatch of data.

on_load(model)

Callback function run in load().

setup()

Flax setup method.

state_dict()

Returns a serialized version of the train state as a dictionary.

to(device)

Move module to device.

train()

Switch to train mode.

Attributes#

JaxBaseModuleClass.device[source]#
JaxBaseModuleClass.name: Optional[str] = None#
JaxBaseModuleClass.params[source]#
JaxBaseModuleClass.parent: Union[Type[Module], Scope, Type[_Sentinel], None] = None#
JaxBaseModuleClass.required_rngs[source]#

Returns a tuple of rng sequence names required for this Flax module.

JaxBaseModuleClass.rngs[source]#

Dictionary of RNGs mapping required RNG names to RNG values.

Calls self._split_rngs(), so newly generated RNGs are returned on every access to self.rngs.
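For example, with the required RNG names from the sketch under the class description, each access yields a freshly split set of keys (hypothetical illustration):

    rngs_a = module.rngs  # e.g. {"params": ..., "dropout": ..., "z": ...}
    rngs_b = module.rngs  # holds different key values than rngs_a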

JaxBaseModuleClass.scope: Optional[Scope] = None#
JaxBaseModuleClass.state[source]#

Methods#

JaxBaseModuleClass.as_bound()[source]#

Module bound with parameters learned from training.

Return type:

JaxBaseModuleClass

static JaxBaseModuleClass.as_numpy_array(x)[source]#

Converts a JAX device array to a NumPy array.

JaxBaseModuleClass.configure()[source]#

Add necessary attrs.

Return type:

None

JaxBaseModuleClass.eval()[source]#

Switch to evaluation mode. Emulates PyTorch's interface.

abstract JaxBaseModuleClass.generative(*args, **kwargs)[source]#

Run the generative model.

This function should return the parameters associated with the likelihood of the data. This is typically written as p(x|z).

This function should return a dictionary with str keys and Array or Distribution values.

Return type:

dict[str, Array | Distribution]
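Continuing the hypothetical sketch from the class description, a generative method for that module might look like the following (the output key name is an assumption):

    def generative(self, z):
        # Decode the latent sample into the parameters of p(x|z); returned
        # values may be arrays or Distribution objects.
        px_mean = self.decoder(z)
        return {"px_mean": px_mean}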

JaxBaseModuleClass.get_jit_inference_fn(get_inference_input_kwargs=None, inference_kwargs=None)[source]#

Create a method to run inference using the bound module.

Parameters:
  • get_inference_input_kwargs (dict[str, Any] | None (default: None)) – Keyword arguments to pass to the subclass's _get_inference_input method.

  • inference_kwargs (dict[str, Any] | None (default: None)) – Keyword arguments to pass to the subclass's inference method.

Return type:

Callable[[dict[str, Array], dict[str, Array]], dict[str, Array]]

Returns:

A callable that takes rngs and array_dict as input and returns the output of the inference method; it runs _get_inference_input internally before calling inference.
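A hypothetical usage pattern, where module is a trained subclass instance and array_dict is a minibatch in the format expected by _get_inference_input:

    # The returned function is jit-compiled, so repeated calls over
    # same-shaped minibatches avoid recompilation.
    inference_fn = module.get_jit_inference_fn()
    outputs = inference_fn(module.rngs, array_dict)  # e.g. {"qz_mean": ..., "z": ...}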

abstract JaxBaseModuleClass.inference(*args, **kwargs)[source]#

Run the recognition model.

In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.

This function should return a dictionary with str keys and Array or Distribution values.

Return type:

dict[str, Array | Distribution]
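Continuing the hypothetical sketch from the class description, an inference method for a simple Gaussian variational posterior might look like:

    def inference(self, x):
        # Compute the parameters of q(z|x) and draw a reparameterized
        # sample using the "z" RNG stream declared in required_rngs.
        h = nn.relu(self.encoder(x))
        qz_mean = self.mean_head(h)
        qz_logvar = self.logvar_head(h)
        eps = jax.random.normal(self.make_rng("z"), qz_mean.shape)
        z = qz_mean + jnp.exp(0.5 * qz_logvar) * eps
        return {"qz_mean": qz_mean, "qz_logvar": qz_logvar, "z": z}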

JaxBaseModuleClass.load_state_dict(state_dict)[source]#

Load a state dictionary into a train state.

abstract JaxBaseModuleClass.loss(*args, **kwargs)[source]#

Compute the loss for a minibatch of data.

This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.

This function should return an object of type LossOutput.

Return type:

LossOutput
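For the same hypothetical module, a loss with a squared-error reconstruction term and an analytic Gaussian KL divergence could be sketched as follows (the "X" minibatch key and the exact signature are assumptions of this sketch):

    def loss(self, tensors, inference_outputs, generative_outputs):
        x = tensors["X"]  # hypothetical minibatch key
        px_mean = generative_outputs["px_mean"]
        qz_mean = inference_outputs["qz_mean"]
        qz_logvar = inference_outputs["qz_logvar"]
        # Reconstruction term plus analytic KL(q(z|x) || N(0, I)).
        recon = jnp.sum((x - px_mean) ** 2, axis=-1)
        kl = -0.5 * jnp.sum(
            1.0 + qz_logvar - qz_mean**2 - jnp.exp(qz_logvar), axis=-1
        )
        return LossOutput(
            loss=jnp.mean(recon + kl), reconstruction_loss=recon, kl_local=kl
        )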

static JaxBaseModuleClass.on_load(model)[source]#

Callback function run in load().

Run one training step prior to loading the state dict in order to initialize the params.

abstract JaxBaseModuleClass.setup()[source]#

Flax setup method.

With scvi-tools we prefer the setup parameterization of flax.linen Modules. This makes the interface more similar to PyTorch's. More about this can be found here:

https://flax.readthedocs.io/en/latest/design_notes/setup_or_nncompact.html
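In practice this means submodules are assigned as attributes inside setup, analogous to layers defined in a torch.nn.Module.__init__, rather than created inline with @nn.compact. A minimal illustration:

    def setup(self):
        # Defined once here, then reused in inference/generative,
        # much like a layer created in a PyTorch __init__.
        self.dense = nn.Dense(10)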

JaxBaseModuleClass.state_dict()[source]#

Returns a serialized version of the train state as a dictionary.

Return type:

dict[str, Any]
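A hypothetical checkpointing round trip with load_state_dict, assuming the returned dictionary is picklable:

    import pickle

    # Serialize the train state ...
    with open("checkpoint.pkl", "wb") as f:
        pickle.dump(module.state_dict(), f)

    # ... and later restore it into a module whose train state has been
    # initialized (see on_load).
    with open("checkpoint.pkl", "rb") as f:
        module.load_state_dict(pickle.load(f))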

JaxBaseModuleClass.to(device)[source]#

Move module to device.

JaxBaseModuleClass.train()[source]#

Switch to train mode. Emulates PyTorch's interface.