scvi.module.base.BaseModuleClass
- class scvi.module.base.BaseModuleClass
Bases: TunableMixin, Module
Abstract class for scvi-tools modules.
Attributes table
Methods table
| forward | Forward pass through the network. |
| generative | Run the generative model. |
| inference | Run the recognition model. |
| loss | Compute the loss for a minibatch of data. |
| on_load | Callback function run in load() prior to loading module state dict. |
| sample | Generate samples from the learned model. |
Attributes
device
training
Methods
forward
- BaseModuleClass.forward(tensors, get_inference_input_kwargs=None, get_generative_input_kwargs=None, inference_kwargs=None, generative_kwargs=None, loss_kwargs=None, compute_loss=True)
Forward pass through the network.
- Parameters:
tensors – Tensors to pass through the network.
get_inference_input_kwargs (dict | None) – Keyword args for _get_inference_input().
get_generative_input_kwargs (dict | None) – Keyword args for _get_generative_input().
inference_kwargs (dict | None) – Keyword args for inference().
generative_kwargs (dict | None) – Keyword args for generative().
loss_kwargs (dict | None) – Keyword args for loss().
compute_loss – Whether to compute the loss on the forward pass. If True, the loss output is returned as an additional (third) value.
- Return type:
tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, scvi.module.base._base_module.LossOutput]
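The parameter list above implies a fixed order of calls: _get_inference_input() → inference() → _get_generative_input() → generative() → loss(). The following is a minimal, self-contained sketch of that routing in plain PyTorch. It assumes each *_kwargs dict is unpacked into its matching stage. It is not the scvi-tools implementation, and ToyForwardModule, the "X" key, the linear encoder/decoder, and the squared-error loss are all hypothetical placeholders.

```python
import torch


class ToyForwardModule(torch.nn.Module):
    """Hypothetical module mirroring the documented forward() routing (not scvi-tools code)."""

    def __init__(self, n_input: int = 4, n_latent: int = 2):
        super().__init__()
        self.encoder = torch.nn.Linear(n_input, n_latent)
        self.decoder = torch.nn.Linear(n_latent, n_input)

    def _get_inference_input(self, tensors):
        # Select what inference() needs from the minibatch ("X" is a hypothetical key).
        return {"x": tensors["X"]}

    def inference(self, x):
        return {"z": self.encoder(x)}

    def _get_generative_input(self, tensors, inference_outputs):
        return {"z": inference_outputs["z"]}

    def generative(self, z):
        return {"x_hat": self.decoder(z)}

    def loss(self, tensors, inference_outputs, generative_outputs):
        # Placeholder loss (mean squared error) standing in for a LossOutput.
        return torch.mean((tensors["X"] - generative_outputs["x_hat"]) ** 2)

    def forward(self, tensors, get_inference_input_kwargs=None,
                get_generative_input_kwargs=None, inference_kwargs=None,
                generative_kwargs=None, loss_kwargs=None, compute_loss=True):
        # Each *_kwargs dict (None means "no extra arguments") is routed to one stage.
        inf_in = self._get_inference_input(tensors, **(get_inference_input_kwargs or {}))
        inf_out = self.inference(**inf_in, **(inference_kwargs or {}))
        gen_in = self._get_generative_input(
            tensors, inf_out, **(get_generative_input_kwargs or {})
        )
        gen_out = self.generative(**gen_in, **(generative_kwargs or {}))
        if not compute_loss:
            return inf_out, gen_out
        # compute_loss=True appends the loss as a third return value.
        return inf_out, gen_out, self.loss(tensors, inf_out, gen_out, **(loss_kwargs or {}))
```

With batch = {"X": torch.randn(8, 4)}, module(batch) returns three values and module(batch, compute_loss=False) returns two. This matches the two tuple shapes in the return type.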
generative
- abstract BaseModuleClass.generative(*args, **kwargs)
Run the generative model.
This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x \mid z)\).
This function should return a dictionary with str keys and Tensor (or Distribution) values.
- Return type:
dict[str, torch.Tensor | torch.distributions.distribution.Distribution]
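As an illustration of a generative() that returns likelihood parameters, the sketch below wraps a decoder's output in a diagonal Gaussian for \(p(x \mid z)\). It also returns a standard normal prior \(p(z)\) for later use in the loss. The Gaussian likelihood, the learned log-scale, the class name ToyDecoder, and the keys "px" and "pz" are assumptions made for this example only.

```python
import torch
from torch.distributions import Normal


class ToyDecoder(torch.nn.Module):
    """Hypothetical decoder illustrating a generative() that returns p(x | z)."""

    def __init__(self, n_latent: int = 2, n_input: int = 4):
        super().__init__()
        self.mean = torch.nn.Linear(n_latent, n_input)
        # Learned per-feature log standard deviation (an assumption of this sketch).
        self.log_scale = torch.nn.Parameter(torch.zeros(n_input))

    def generative(self, z: torch.Tensor) -> dict:
        # Likelihood p(x | z): diagonal Gaussian whose mean is decoded from z.
        px = Normal(self.mean(z), self.log_scale.exp())
        # Prior p(z) = N(0, I), returned so that loss() can compute a KL term.
        pz = Normal(torch.zeros_like(z), torch.ones_like(z))
        return {"px": px, "pz": pz}
```

Returning Distribution objects rather than raw tensors lets loss() call log_prob() directly. This is permitted by the Distribution part of the return type.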
inference
- abstract BaseModuleClass.inference(*args, **kwargs)
Run the recognition model.
In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.
This function should return a dictionary with str keys and Tensor (or Distribution) values.
- Return type:
dict[str, torch.Tensor | torch.distributions.distribution.Distribution]
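A minimal VAE-style inference() is sketched below. An encoder network maps x to the mean and log-variance of a diagonal Gaussian \(q(z \mid x)\), and a latent sample is drawn with rsample() (the reparameterization trick) so that gradients reach the encoder. ToyEncoder, the layer sizes, and the output keys "qz" and "z" are hypothetical.

```python
import torch
from torch.distributions import Normal


class ToyEncoder(torch.nn.Module):
    """Hypothetical encoder illustrating an inference() step for a VAE."""

    def __init__(self, n_input: int = 4, n_latent: int = 2, n_hidden: int = 16):
        super().__init__()
        self.hidden = torch.nn.Sequential(torch.nn.Linear(n_input, n_hidden), torch.nn.ReLU())
        self.mean = torch.nn.Linear(n_hidden, n_latent)
        self.log_var = torch.nn.Linear(n_hidden, n_latent)

    def inference(self, x: torch.Tensor) -> dict:
        h = self.hidden(x)
        # Variational posterior q(z | x): diagonal Gaussian with std = exp(log_var / 2).
        qz = Normal(self.mean(h), torch.exp(0.5 * self.log_var(h)))
        # rsample() draws z = mean + std * eps, keeping z differentiable w.r.t. the encoder.
        z = qz.rsample()
        return {"qz": qz, "z": z}
```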
loss
- abstract BaseModuleClass.loss(*args, **kwargs)
Compute the loss for a minibatch of data.
This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.
This function should return an object of type LossOutput.
- Return type:
LossOutput
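The sketch below combines the outputs of inference() and generative() into a negative evidence lower bound. It adds a reconstruction term \(-\log p(x \mid z)\) to a KL divergence \(\mathrm{KL}(q(z \mid x) \,\|\, p(z))\) scaled by an optional kl_weight. LossOutput is part of scvi-tools, so the example instead returns a hypothetical ToyLossOutput dataclass whose field names are only loosely modeled on it. The "px", "pz", and "qz" keys are likewise assumptions.

```python
from dataclasses import dataclass

import torch
from torch.distributions import Normal, kl_divergence


@dataclass
class ToyLossOutput:
    """Hypothetical stand-in for LossOutput (not the scvi-tools class)."""

    loss: torch.Tensor
    reconstruction_loss: torch.Tensor
    kl_local: torch.Tensor


def toy_loss(x, inference_outputs, generative_outputs, kl_weight: float = 1.0):
    # Reconstruction term: -log p(x | z), summed over features (one value per cell).
    reconst = -generative_outputs["px"].log_prob(x).sum(-1)
    # KL(q(z | x) || p(z)), summed over latent dimensions (one value per cell).
    kl = kl_divergence(inference_outputs["qz"], generative_outputs["pz"]).sum(-1)
    # Negative ELBO averaged over the minibatch; extra penalty terms would be added here.
    loss = torch.mean(reconst + kl_weight * kl)
    return ToyLossOutput(loss=loss, reconstruction_loss=reconst, kl_local=kl)
```

Keeping the per-observation terms alongside the scalar loss lets a training loop log each component separately.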
on_load
- BaseModuleClass.on_load(model)
Callback function run in load() prior to loading the module state dict.
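on_load() runs before the state dict is loaded, so it is one place to make a freshly constructed module compatible with saved weights. The sketch below assumes a hypothetical case in which the saved state contains a buffer that a new module lacks. Without the hook, strict loading would fail on the unexpected key. hypothetical_load() only imitates the documented ordering (on_load, then loading the state dict) and is not the scvi-tools load(). The model argument is unused here.

```python
import torch


class ToyModule(torch.nn.Module):
    """Hypothetical module with an on_load hook (not scvi-tools code)."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def on_load(self, model):
        # Hypothetical fix-up: create a buffer that exists in the saved state dict
        # but not in a freshly constructed module, so strict loading succeeds.
        self.register_buffer("extra", torch.zeros(2))


def hypothetical_load(module, model, state_dict):
    # Mirrors the documented ordering: the callback runs before the weights load.
    module.on_load(model)
    module.load_state_dict(state_dict)
    return module
```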
sample