scvi.module.base.JaxModuleWrapper
- class scvi.module.base.JaxModuleWrapper(module_cls, seed=0, **module_kwargs)
Bases: object
Wrapper class for Flax (Jax-backed) modules used to interact with model classes.
This class maintains all state necessary for training, as well as for updating that state via the Flax module. The Flax module itself should remain stateless. In addition, JaxModuleWrapper duck-types the methods of BaseModuleClass, which supports PyTorch-backed modules, to provide a consistent interface for BaseModelClass.
- Parameters:
  - module_cls (JaxBaseModuleClass) – Flax module class to wrap.
  - seed (int (default: 0)) – Random seed to initialize Jax RNGs with.
  - **module_kwargs – Keyword arguments used to initialize module_cls.
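To make the division of responsibilities concrete, here is a plain-Python sketch of the pattern described above: a stateless module class plus a wrapper that owns the parameters, the seed, and the training flag, and exposes PyTorch-style train(), eval(), and to() methods. ToyStatelessModule, ToyModuleWrapper, and the toy arithmetic are hypothetical stand-ins. The sketch uses no JAX, Flax, or scvi-tools code and is not the library's implementation.

```python
from typing import Any, Callable, Dict


class ToyStatelessModule:
    """Hypothetical stand-in for a Flax module: stores configuration only, no learned state."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def apply(self, params: Dict[str, float], x: float, training: bool) -> float:
        # Pure function of its inputs; halving the output in training mode is a toy
        # placeholder for mode-dependent behavior such as dropout.
        factor = 0.5 if training else 1.0
        return params["w"] * self.scale * x * factor


class ToyModuleWrapper:
    """Hypothetical wrapper that owns every piece of mutable state for a stateless module."""

    def __init__(self, module_cls: Callable[..., Any], seed: int = 0, **module_kwargs: Any) -> None:
        self.module = module_cls(**module_kwargs)  # module built from the forwarded kwargs
        self.seed = seed  # stored here; a real wrapper would derive its RNG keys from it
        self.params: Dict[str, float] = {"w": 1.0}  # toy "learned" parameters
        self.training = True
        self.device = "cpu"

    # The methods below mimic a PyTorch-style module interface (duck typing).
    def train(self) -> None:
        self.training = True

    def eval(self) -> None:
        self.training = False

    def to(self, device: str) -> None:
        self.device = device

    def __call__(self, x: float) -> float:
        # State lives in the wrapper and is passed into the stateless apply call.
        return self.module.apply(self.params, x, training=self.training)


wrapper = ToyModuleWrapper(ToyStatelessModule, seed=0, scale=2.0)
wrapper.eval()
print(wrapper(3.0))  # prints 6.0 (evaluation mode: 1.0 * 2.0 * 3.0 * 1.0)
```

Because the module holds no learned state, swapping parameters or switching modes only touches the wrapper; the module's apply stays a pure function.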
Attributes table

| Attribute | Description |
|---|---|
| apply | Apply function of the Flax module. |
| device | |
| init | Init function of the Flax module. |
| loss | Loss function of the Flax module. |
| module | |
| params | |
| rngs | Dictionary of RNGs mapping required RNG name to RNG values. |
| state | |
| train_state | Train state containing learned parameter values from training. |
| training | Whether or not the Flax module is in training mode. |
Methods table

| Method | Description |
|---|---|
| eval | Switch to evaluation mode. |
| get_inference_fn | Returns a method to run inference using the bound module. |
| load_state_dict | Load a state dictionary into a train state. |
| on_load | Callback function run in scvi.model.base.BaseModelClass.load() prior to loading the module state dict. |
| state_dict | Returns a serialized version of the train state as a dictionary. |
| to | Move module to device. |
| train | Switch to train mode. |
Attributes
apply
device
init
loss
module
params
- JaxModuleWrapper.params
- Return type: FrozenDict[str, Any]
rngs
- JaxModuleWrapper.rngs
Dictionary of RNGs mapping required RNG name to RNG values.
Calls self._split_rngs(), so newly generated RNGs are returned on every access to self.rngs.
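The "fresh keys on every access" behavior can be sketched without JAX. Below, split_key is a hypothetical stand-in for a JAX key split (in JAX itself, jax.random.split plays this role), and RngHolder, seed_rng, and required_rngs are illustrative names, not the scvi-tools internals. The point is only that a property which re-splits on each access never hands out the same key twice in a row.

```python
import random
from typing import Dict, Tuple


def split_key(key: int) -> Tuple[int, int]:
    """Hypothetical stand-in for a JAX key split: deterministically derive two new keys."""
    rng = random.Random(key)
    return rng.getrandbits(32), rng.getrandbits(32)


class RngHolder:
    def __init__(self, seed: int = 0, required_rngs: Tuple[str, ...] = ("params", "dropout", "z")) -> None:
        self.seed_rng = seed  # the only persistent key; subkeys are derived from it
        self.required_rngs = required_rngs

    def _split_rngs(self) -> Dict[str, int]:
        new_rngs = {}
        for name in self.required_rngs:
            # Advance the stored key and hand out a fresh subkey for this RNG name.
            self.seed_rng, subkey = split_key(self.seed_rng)
            new_rngs[name] = subkey
        return new_rngs

    @property
    def rngs(self) -> Dict[str, int]:
        # Each attribute access performs a new split, so every access yields new keys.
        return self._split_rngs()


holder = RngHolder(seed=0)
first = holder.rngs
second = holder.rngs  # expected to differ: the stored key advanced during the first access
```

A consequence worth keeping in mind: reading the property has a side effect, so code that needs the same keys twice should store the returned dictionary rather than reading the property again.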
state
- JaxModuleWrapper.state
- Return type: FrozenDict[str, Any]
train_state
- JaxModuleWrapper.train_state
Train state containing learned parameter values from training.
- Return type: TrainStateWithState
training
Methods
eval
get_inference_fn
- JaxModuleWrapper.get_inference_fn(mc_samples=1)
Returns a method to run inference using the bound module.
- Parameters:
  - mc_samples (int (default: 1)) – Number of Monte Carlo samples to run for each input.
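The role of mc_samples can be illustrated with a hedged sketch: a factory that binds a model function and returns a closure drawing mc_samples samples per input. The unit Gaussian noise, the list-based inputs, and every name here are hypothetical; the real method works on the bound Flax module, and its exact input and output structures are not documented on this page.

```python
import random
from typing import Callable, List


def get_inference_fn(
    model_fn: Callable[[float], float], mc_samples: int = 1, seed: int = 0
) -> Callable[[List[float]], List[List[float]]]:
    """Hypothetical sketch: bind a model function and return an inference closure."""
    rng = random.Random(seed)  # RNG captured by the closure

    def inference(inputs: List[float]) -> List[List[float]]:
        # For each input, draw mc_samples noisy samples around the bound model's output.
        return [[model_fn(x) + rng.gauss(0.0, 1.0) for _ in range(mc_samples)] for x in inputs]

    return inference


run = get_inference_fn(lambda x: 2.0 * x, mc_samples=4)
samples = run([1.0, 2.0, 3.0])  # 3 inputs, 4 Monte Carlo samples each
```

Returning a closure means the binding step (choosing the model and sample count) happens once, and the returned function can then be called repeatedly on new batches.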
load_state_dict
on_load
- static JaxModuleWrapper.on_load(model)
Callback function run in scvi.model.base.BaseModelClass.load() prior to loading the module state dict.
For some Pyro modules with AutoGuides, this runs one training step prior to loading the state dict.
state_dict
- JaxModuleWrapper.state_dict()
Returns a serialized version of the train state as a dictionary.
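A minimal sketch of the state_dict / load_state_dict round trip, assuming a train state made of a step counter, parameters, and extra module state. ToyTrainState and the two functions are hypothetical plain-Python stand-ins for the TrainStateWithState-based implementation. Flax itself ships flax.serialization.to_state_dict and from_state_dict for this kind of conversion, but whether this class relies on them is not stated here.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict


@dataclass(frozen=True)
class ToyTrainState:
    """Hypothetical immutable train state: step counter, parameters, and extra module state."""

    step: int
    params: Dict[str, Any]
    state: Dict[str, Any] = field(default_factory=dict)


def state_dict(train_state: ToyTrainState) -> Dict[str, Any]:
    # Serialize to plain dictionaries (shallow copies of the top-level mappings).
    return {
        "step": train_state.step,
        "params": dict(train_state.params),
        "state": dict(train_state.state),
    }


def load_state_dict(train_state: ToyTrainState, saved: Dict[str, Any]) -> ToyTrainState:
    # The train state is frozen, so build a new one instead of mutating in place.
    return replace(
        train_state,
        step=saved["step"],
        params=dict(saved["params"]),
        state=dict(saved["state"]),
    )


trained = ToyTrainState(step=100, params={"w": 1.5}, state={"batch_stats": 0.2})
fresh = ToyTrainState(step=0, params={"w": 0.0})
restored = load_state_dict(fresh, state_dict(trained))
```

Returning a new object from load_state_dict mirrors the immutable style of Flax train states, where updates produce a replacement rather than modifying the existing state.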
to
train