scvi.module.base.BaseModuleClass

class scvi.module.base.BaseModuleClass

Bases: Module

Abstract class for scvi-tools modules.

Attributes table

Methods table

forward(tensors[, ...])

Forward pass through the network.

generative(*args, **kwargs)

Run the generative model.

inference(*args, **kwargs)

Run the inference (recognition) model.

loss(*args, **kwargs)

Compute the loss for a minibatch of data.

on_load(model)

Callback function run in scvi.model.base.BaseModelClass.load() prior to loading the module state dict.

sample(*args, **kwargs)

Generate samples from the learned model.

Attributes

T_destination

BaseModuleClass.T_destination

alias of TypeVar('T_destination', bound=Dict[str, Any])

device

BaseModuleClass.device

dump_patches

BaseModuleClass.dump_patches: bool = False

training

BaseModuleClass.training: bool

Methods

forward

BaseModuleClass.forward(tensors, get_inference_input_kwargs=None, get_generative_input_kwargs=None, inference_kwargs=None, generative_kwargs=None, loss_kwargs=None, compute_loss=True)

Forward pass through the network.

Parameters:
tensors

Tensors to pass through the network.

get_inference_input_kwargs : Optional[dict] (default: None)

Keyword args for _get_inference_input()

get_generative_input_kwargs : Optional[dict] (default: None)

Keyword args for _get_generative_input()

inference_kwargs : Optional[dict] (default: None)

Keyword args for inference()

generative_kwargs : Optional[dict] (default: None)

Keyword args for generative()

loss_kwargs : Optional[dict] (default: None)

Keyword args for loss()

compute_loss

Whether to compute the loss on the forward pass. If True, the LossRecorder returned by loss() is added as a third return value.

Return type:

Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, LossRecorder]]
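How forward() routes its keyword dictionaries can be sketched in plain Python. This is a minimal mock, not the scvi-tools implementation: the class ToyModule, the signatures given to _get_inference_input and _get_generative_input, the toy arithmetic, and the dictionary keys ("x", "z", "px_mean") are all hypothetical. Only the hook names and the compute_loss behavior come from the signature above; a real module works on torch tensors and returns a LossRecorder instead of a float.

```python
from typing import Dict, Optional, Tuple, Union


class ToyModule:
    """Hypothetical stand-in that mimics how forward() chains the module hooks."""

    def _get_inference_input(self, tensors: Dict[str, float]) -> Dict[str, float]:
        # Pick out what the inference model needs from the minibatch (assumed key "x").
        return {"x": tensors["x"]}

    def inference(self, x: float, scale: float = 0.5) -> Dict[str, float]:
        # Toy deterministic "encoder".
        return {"z": scale * x}

    def _get_generative_input(self, tensors, inference_outputs) -> Dict[str, float]:
        # The generative model consumes the latent produced by inference().
        return {"z": inference_outputs["z"]}

    def generative(self, z: float) -> Dict[str, float]:
        # Toy "decoder": mean of the likelihood p(x|z).
        return {"px_mean": 2.0 * z}

    def loss(self, tensors, inference_outputs, generative_outputs, weight: float = 1.0) -> float:
        # Squared reconstruction error; a real module would return a LossRecorder.
        return weight * (tensors["x"] - generative_outputs["px_mean"]) ** 2

    def forward(
        self,
        tensors: Dict[str, float],
        get_inference_input_kwargs: Optional[dict] = None,
        get_generative_input_kwargs: Optional[dict] = None,
        inference_kwargs: Optional[dict] = None,
        generative_kwargs: Optional[dict] = None,
        loss_kwargs: Optional[dict] = None,
        compute_loss: bool = True,
    ) -> Union[Tuple[dict, dict], Tuple[dict, dict, float]]:
        # Each *_kwargs dict is routed to exactly one step; None means "no extra arguments".
        inference_inputs = self._get_inference_input(tensors, **(get_inference_input_kwargs or {}))
        inference_outputs = self.inference(**inference_inputs, **(inference_kwargs or {}))
        generative_inputs = self._get_generative_input(
            tensors, inference_outputs, **(get_generative_input_kwargs or {})
        )
        generative_outputs = self.generative(**generative_inputs, **(generative_kwargs or {}))
        if not compute_loss:
            return inference_outputs, generative_outputs
        # compute_loss=True adds the loss as a third return value.
        losses = self.loss(tensors, inference_outputs, generative_outputs, **(loss_kwargs or {}))
        return inference_outputs, generative_outputs, losses
```

With the defaults, ToyModule().forward({"x": 4.0}, compute_loss=False) returns ({"z": 2.0}, {"px_mean": 4.0}); passing inference_kwargs={"scale": 0.25} affects only the inference() call.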

generative

abstract BaseModuleClass.generative(*args, **kwargs)

Run the generative model.

This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x|z)\).

This function should return a dictionary with str keys and Tensor (or Distribution) values.

Return type:

Dict[str, Union[Tensor, Distribution]]
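The return contract of generative() (a dictionary whose values are either raw parameters or distribution objects) can be illustrated without torch. In this hypothetical sketch, toy_generative, its linear decoder weights w and b, the fixed noise_sd, and the keys "px_mean" and "px" are assumptions; statistics.NormalDist stands in for a torch Distribution.

```python
from statistics import NormalDist
from typing import Dict, Union


def toy_generative(
    z: float, w: float = 2.0, b: float = 0.5, noise_sd: float = 1.0
) -> Dict[str, Union[float, NormalDist]]:
    """Hypothetical decoder returning the parameters of p(x|z)."""
    px_mean = w * z + b  # toy linear "decoder"
    return {
        "px_mean": px_mean,                   # raw parameter (a Tensor in a real module)
        "px": NormalDist(px_mean, noise_sd),  # distribution object for p(x|z)
    }
```

Returning the distribution object alongside its parameters lets a loss function call its density directly while other code can still read the raw mean.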

inference

abstract BaseModuleClass.inference(*args, **kwargs)

Run the inference (recognition) model.

In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.

This function should return a dictionary with str keys and Tensor (or Distribution) values.

Return type:

Dict[str, Union[Tensor, Distribution]]
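In a VAE, the encoder typically outputs the mean and variance of \(q(z|x)\), and a latent sample is commonly drawn with the reparameterization trick, \(z = \mu + \sigma \epsilon\) with \(\epsilon \sim N(0, 1)\). The plain-Python sketch below shows that step under stated assumptions: toy_inference, its scalar "encoder heads" w_mean and w_logvar, and the output keys are hypothetical, and statistics.NormalDist stands in for a torch Distribution. The optional eps argument exists only to make the sketch deterministic.

```python
import math
import random
from statistics import NormalDist
from typing import Dict, Optional, Union


def toy_inference(
    x: float,
    w_mean: float = 0.5,
    w_logvar: float = 0.0,
    eps: Optional[float] = None,
    rng: Optional[random.Random] = None,
) -> Dict[str, Union[float, NormalDist]]:
    """Hypothetical encoder: maps x to q(z|x) and draws z by reparameterization."""
    qz_mean = w_mean * x                  # toy mean head
    qz_var = math.exp(w_logvar * x)       # toy log-variance head, exponentiated to stay positive
    if eps is None:
        eps = (rng or random.Random()).gauss(0.0, 1.0)  # noise from N(0, 1)
    z = qz_mean + math.sqrt(qz_var) * eps  # reparameterized sample from q(z|x)
    return {
        "qz": NormalDist(qz_mean, math.sqrt(qz_var)),
        "qz_mean": qz_mean,
        "qz_var": qz_var,
        "z": z,
    }
```

Writing the sample as a deterministic function of the encoder outputs plus independent noise is what lets gradients flow through \(z\) in an autodiff framework.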

loss

abstract BaseModuleClass.loss(*args, **kwargs)

Compute the loss for a minibatch of data.

This function uses the outputs of inference() and generative() to compute a loss. It may optionally include other penalty terms, which should also be computed here.

This function should return an object of type LossRecorder.

Return type:

LossRecorder
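For a standard VAE, a common choice of loss is the negative evidence lower bound (ELBO), which combines a reconstruction term from generative() with a KL term from inference():

\[
\mathcal{L}(x) = -\mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] + \mathrm{KL}\left(q(z|x) \,\|\, p(z)\right)
\]

The sketch below computes a one-sample estimate of this quantity for a 1-D model with Gaussian \(q(z|x)\), standard normal prior \(p(z)\), and Gaussian likelihood. It is an illustration under those assumptions, not the loss of any particular scvi-tools module: toy_loss, kl_weight, and the returned keys are hypothetical, and the plain dict stands in for a LossRecorder.

```python
import math
from typing import Dict


def gaussian_kl_to_standard_normal(mean: float, var: float) -> float:
    # Closed-form KL( N(mean, var) || N(0, 1) ) for a 1-D Gaussian.
    return 0.5 * (var + mean ** 2 - 1.0 - math.log(var))


def gaussian_log_likelihood(x: float, mean: float, var: float) -> float:
    # log N(x; mean, var)
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)


def toy_loss(
    x: float,
    qz_mean: float,
    qz_var: float,
    px_mean: float,
    px_var: float = 1.0,
    kl_weight: float = 1.0,
) -> Dict[str, float]:
    """Hypothetical one-sample negative ELBO; px_mean comes from decoding one sampled z."""
    reconstruction_loss = -gaussian_log_likelihood(x, px_mean, px_var)
    kl_local = gaussian_kl_to_standard_normal(qz_mean, qz_var)
    return {
        "reconstruction_loss": reconstruction_loss,
        "kl_local": kl_local,
        "loss": reconstruction_loss + kl_weight * kl_local,
    }
```

Any extra penalty terms would be added to "loss" in the same function, which matches the guidance above that they be computed inside loss().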

on_load

BaseModuleClass.on_load(model)

Callback function run in scvi.model.base.BaseModelClass.load() prior to loading the module state dict.

sample

abstract BaseModuleClass.sample(*args, **kwargs)

Generate samples from the learned model.
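For a latent-variable model, one common way to generate samples is ancestral sampling: draw \(z\) from the prior \(p(z)\), then draw \(x\) from \(p(x|z)\). The sketch below shows that pattern in plain Python under loud assumptions: toy_sample, the standard normal prior, the linear-Gaussian likelihood with parameters w, b, and noise_sd, and the seed argument are all hypothetical and do not describe what any concrete scvi-tools module's sample() does.

```python
import random
from typing import List


def toy_sample(
    n_samples: int, w: float = 2.0, b: float = 0.5, noise_sd: float = 1.0, seed: int = 0
) -> List[float]:
    """Hypothetical ancestral sampler: z ~ N(0, 1), then x ~ N(w * z + b, noise_sd**2)."""
    rng = random.Random(seed)  # seeded generator so repeated calls are reproducible
    samples = []
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)             # draw the latent from the prior p(z)
        x = rng.gauss(w * z + b, noise_sd)  # draw the observation from p(x|z)
        samples.append(x)
    return samples
```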