scvi.module.base.JaxModuleWrapper

class scvi.module.base.JaxModuleWrapper(module_cls, seed=0, **module_kwargs)

Bases: object

Wrapper class for Flax (JAX-backed) modules, used to interact with model classes.

This class holds all of the state needed for training and updates that state by calling into the Flax module, which itself should remain stateless. In addition, JaxModuleWrapper duck-types the methods of BaseModuleClass (the base class for PyTorch-backed modules) so that BaseModelClass can work with JAX-backed and PyTorch-backed modules through a consistent interface.
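The split between a stateless module and a stateful wrapper can be illustrated without Flax. In the minimal sketch below, which assumes nothing about scvi-tools internals, the hypothetical function `linear_apply` plays the role of a pure apply function and the hypothetical class `ToyWrapper` owns the parameters, so callers never pass state explicitly.

```python
import jax
import jax.numpy as jnp


def linear_apply(params, x):
    # Pure function: the output depends only on its arguments,
    # like the apply function of a stateless Flax module.
    return x @ params["w"] + params["b"]


class ToyWrapper:
    """Hypothetical wrapper that owns the state a stateless apply function needs."""

    def __init__(self, in_dim, out_dim, seed=0):
        key = jax.random.PRNGKey(seed)
        self.params = {
            "w": jax.random.normal(key, (in_dim, out_dim)),
            "b": jnp.zeros((out_dim,)),
        }

    def __call__(self, x):
        # The wrapper supplies its stored params, so callers never handle state.
        return linear_apply(self.params, x)


wrapper = ToyWrapper(in_dim=3, out_dim=2)
out = wrapper(jnp.ones((4, 3)))
print(out.shape)  # (4, 2)
```

Keeping `linear_apply` pure means it can be passed to `jax.jit` or `jax.grad` directly, while the wrapper is the single place where parameters are stored and replaced.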

Parameters:
  • module_cls (JaxBaseModuleClass) – Flax module class to wrap.

  • seed (int (default: 0)) – Random seed used to initialize the JAX RNGs.

  • **module_kwargs – Keyword arguments that will be used to initialize module_cls.

Attributes table

apply

Apply function of the Flax module.

device

init

Init function of the Flax module.

loss

Loss function of the Flax module.

module

params

Return type:

FrozenDict[str, Any]

rngs

Dictionary of RNGs mapping required RNG name to RNG values.

state

Return type:

FrozenDict[str, Any]

train_state

Train state containing learned parameter values from training.

training

Whether or not the Flax module is in training mode.

Methods table

eval()

Switch to evaluation mode.

get_inference_fn([mc_samples])

Returns a method to run inference using the bound module.

load_state_dict(state_dict)

Load a state dictionary into a train state.

on_load(model)

Callback function run in BaseModelClass.load() (scvi.model.base.BaseModelClass.load) prior to loading the module state dict.

state_dict()

Returns a serialized version of the train state as a dictionary.

to(device)

Move module to device.

train()

Switch to train mode.

Attributes

apply

JaxModuleWrapper.apply

Apply function of the Flax module.

device

JaxModuleWrapper.device

init

JaxModuleWrapper.init

Init function of the Flax module.

loss

JaxModuleWrapper.loss

Loss function of the Flax module.

module

JaxModuleWrapper.module

params

JaxModuleWrapper.params

Return type:

FrozenDict[str, Any]

rngs

JaxModuleWrapper.rngs

Dictionary of RNGs mapping required RNG name to RNG values.

Calls self._split_rngs(), so every access to self.rngs returns newly generated RNGs.

Return type:

Dict[str, Array]
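Because reading `rngs` splits fresh keys, two reads never return the same keys. The minimal sketch below reproduces that behavior with `jax.random.split`; the class `RngHolder`, its `_split_rngs` body, and the RNG names `params` and `dropout` are hypothetical placeholders, not the scvi-tools implementation.

```python
import jax


class RngHolder:
    """Hypothetical sketch: hand out fresh RNG keys on every attribute access."""

    def __init__(self, seed=0, rng_names=("params", "dropout")):
        self.seed_rng = jax.random.PRNGKey(seed)
        self.rng_names = rng_names

    def _split_rngs(self):
        # Split into one key per RNG name plus one extra key.
        keys = jax.random.split(self.seed_rng, len(self.rng_names) + 1)
        # Keep the first key as the new seed so the next split differs.
        self.seed_rng = keys[0]
        return {name: key for name, key in zip(self.rng_names, keys[1:])}

    @property
    def rngs(self):
        # Each property access advances the seed and returns new keys.
        return self._split_rngs()


holder = RngHolder(seed=0)
first = holder.rngs
second = holder.rngs
```

In this sketch, code that needs the same keys twice within one step should read `rngs` once and reuse the resulting dictionary.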

state

JaxModuleWrapper.state

Return type:

FrozenDict[str, Any]

train_state

JaxModuleWrapper.train_state

Train state containing learned parameter values from training.

Return type:

TrainStateWithState

training

JaxModuleWrapper.training

Whether or not the Flax module is in training mode.

Methods

eval

JaxModuleWrapper.eval()

Switch to evaluation mode. Emulates PyTorch’s interface.

get_inference_fn

JaxModuleWrapper.get_inference_fn(mc_samples=1)

Returns a method to run inference using the bound module.

Parameters:

mc_samples (int (default: 1)) – Number of Monte Carlo samples to run for each input.
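The idea of returning a reusable inference function that draws `mc_samples` Monte Carlo samples per input can be sketched with a toy Gaussian encoder. The sketch below is illustrative only: the real method binds the wrapped Flax module, whereas here the hypothetical parameters `w_mean` and `w_log_var` and the `(rng, x)` call signature are assumptions.

```python
import jax
import jax.numpy as jnp


def get_inference_fn(params, mc_samples=1):
    """Hypothetical sketch: return a jitted function drawing mc_samples latent samples per input."""

    @jax.jit
    def _run_inference(rng, x):
        # Toy encoder: mean and log-variance are linear functions of the input.
        mean = x @ params["w_mean"]
        log_var = x @ params["w_log_var"]
        # Reparameterization: z = mean + std * eps with eps ~ N(0, I),
        # adding a leading axis of size mc_samples.
        eps = jax.random.normal(rng, (mc_samples,) + mean.shape)
        return mean + jnp.exp(0.5 * log_var) * eps

    return _run_inference


params = {"w_mean": jnp.ones((3, 2)), "w_log_var": jnp.zeros((3, 2))}
inference_fn = get_inference_fn(params, mc_samples=5)
z = inference_fn(jax.random.PRNGKey(0), jnp.ones((4, 3)))
print(z.shape)  # (5, 4, 2)
```

Returning a closure lets `jax.jit` compile the inference function once and reuse it across minibatches; `mc_samples` is captured as a Python constant, so it stays static under compilation.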

load_state_dict

JaxModuleWrapper.load_state_dict(state_dict)

Load a state dictionary into a train state.

on_load

static JaxModuleWrapper.on_load(model)

Callback function run in BaseModelClass.load() (scvi.model.base.BaseModelClass.load) prior to loading the module state dict.

For some Pyro modules with AutoGuides, runs one training step before the state dict is loaded.

state_dict

JaxModuleWrapper.state_dict()

Returns a serialized version of the train state as a dictionary.

Return type:

Dict[str, Any]
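A `state_dict`/`load_state_dict` pair lets trained parameters be saved and restored through a plain dictionary, as in PyTorch. The minimal round-trip sketch below uses the hypothetical class `ToyState` and a `{"params": ...}` layout that are assumptions, not the serialization format scvi-tools actually writes.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyState:
    """Hypothetical holder for trained parameters with a PyTorch-style state_dict interface."""

    def __init__(self, params):
        self.params = params

    def state_dict(self):
        # Convert every JAX array leaf to a NumPy array so the dict can be pickled or saved.
        return {"params": jax.tree_util.tree_map(np.asarray, self.params)}

    def load_state_dict(self, state_dict):
        # Rebuild JAX arrays from the serialized leaves, keeping the nested structure.
        self.params = jax.tree_util.tree_map(jnp.asarray, state_dict["params"])


src = ToyState({"encoder": {"w": jnp.arange(6.0).reshape(2, 3)}})
saved = src.state_dict()
dst = ToyState({"encoder": {"w": jnp.zeros((2, 3))}})
dst.load_state_dict(saved)
```

Using `jax.tree_util.tree_map` keeps the nested structure of the parameters intact, so the loaded dictionary matches what the module expects.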

to

JaxModuleWrapper.to(device)

Move module to device.
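`to(device)` mirrors PyTorch's `Module.to`. In JAX, arrays are moved between devices with `jax.device_put`, which accepts a whole pytree. The sketch below assumes, without confirmation from this page, that moving a module amounts to transferring its parameter pytree; the `params` dictionary is a hypothetical placeholder.

```python
import jax
import jax.numpy as jnp

# Pick the first available device (the CPU if no accelerator is present).
device = jax.devices()[0]

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
# device_put accepts a pytree and transfers every leaf to the target device.
params_on_device = jax.device_put(params, device)
```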

train

JaxModuleWrapper.train()

Switch to train mode. Emulates PyTorch’s interface.
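Because the Flax module is stateless, the training mode has to be carried by the wrapper and passed in when the module is applied. The minimal sketch below shows PyTorch-style `train()`/`eval()` toggles controlling dropout; the class `ModeWrapper` and the function `dropout_apply` are hypothetical and only illustrate the pattern.

```python
import jax
import jax.numpy as jnp


def dropout_apply(x, rng, training, rate=0.5):
    # Pure function: dropout is active only when training is True.
    if not training:
        return x
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    # Scale kept units so the expected activation is unchanged.
    return jnp.where(keep, x / (1.0 - rate), 0.0)


class ModeWrapper:
    """Hypothetical wrapper emulating PyTorch's train()/eval() with a boolean flag."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def __call__(self, x, rng):
        # The stored flag is forwarded to the stateless apply function.
        return dropout_apply(x, rng, self.training)


wrapper = ModeWrapper()
wrapper.eval()
x = jnp.ones((3,))
out_eval = wrapper(x, jax.random.PRNGKey(0))
```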