scvi.module.VAEC

class scvi.module.VAEC(n_input, n_labels=0, n_hidden=128, n_latent=5, n_layers=2, dropout_rate=0.1, log_variational=True, ct_weight=None, **module_kwargs)

Conditional variational autoencoder model.

This is an implementation of the CondSCVI model.

Parameters
n_input : int

Number of input genes

n_labels : int (default: 0)

Number of labels (e.g. cell types) the model is conditioned on

n_hidden : int (default: 128)

Number of nodes per hidden layer

n_latent : int (default: 5)

Dimensionality of the latent space

n_layers : int (default: 2)

Number of hidden layers used for encoder and decoder NNs

dropout_rate : float (default: 0.1)

Dropout rate for the encoder neural network

log_variational : bool (default: True)

Apply log(data + 1) to the input before encoding, for numerical stability. This is not a normalization.
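The label conditioning and log_variational can be illustrated with a small, dependency-free sketch of how an encoder input vector might be assembled for one cell. It assumes, for illustration only, that the label is injected as a one-hot vector appended to the log-transformed counts; prepare_encoder_input is a hypothetical helper, not part of scvi-tools, and the real module operates on torch tensors inside its encoder layers.

```python
import math
from typing import List


def prepare_encoder_input(
    counts: List[float], label: int, n_labels: int, log_variational: bool = True
) -> List[float]:
    """Hypothetical helper: features an encoder conditioned on a label could see."""
    if not 0 <= label < n_labels:
        raise ValueError("label must lie in [0, n_labels)")
    # log(data + 1) for numerical stability; this is not a normalization step.
    features = [math.log1p(c) for c in counts] if log_variational else list(counts)
    # Assumption: the label enters as a one-hot vector appended to the features.
    one_hot = [1.0 if i == label else 0.0 for i in range(n_labels)]
    return features + one_hot


x = prepare_encoder_input([0.0, 3.0, 10.0], label=1, n_labels=3)
print([round(v, 3) for v in x])  # [0.0, 1.386, 2.398, 0.0, 1.0, 0.0]
```

math.log1p is used rather than math.log(1 + c) because it stays accurate for counts close to zero.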

Methods table

generative(z, library, y)

Runs the generative model.

inference(x, y[, n_samples])

High-level inference method.

loss(tensors, inference_outputs, ...[, ...])

Compute the loss for a minibatch of data.

sample(tensors[, n_samples])

Generate observation samples from the posterior predictive distribution.

Attributes

T_destination

VAEC.T_destination

alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])

device

VAEC.device

dump_patches

VAEC.dump_patches: bool = False

This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number is saved in the _metadata attribute of the returned state dict, and is thus pickled. _metadata is a dictionary whose keys follow the naming convention of the state dict. See _load_from_state_dict for how to use this information when loading.

If parameters or buffers are added to or removed from a module, this number should be bumped, and the module's _load_from_state_dict method can compare the version number and make the appropriate changes if the state dict predates the change.
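A pure-Python sketch of the version-bump idea: a loader inspects the saved version and rewrites keys from an older layout before loading. In PyTorch the saved version reaches _load_from_state_dict through its local_metadata argument; here migrate_state_dict, the key names, and the rename from "scale" to "px_scale" are all hypothetical, and no torch objects are involved.

```python
from typing import Dict, Optional


def migrate_state_dict(
    state: Dict[str, float], version: Optional[int]
) -> Dict[str, float]:
    """Hypothetical: upgrade a state dict saved before version 2.

    Scenario (invented for illustration): version 2 renamed "scale" to
    "px_scale". A missing version (None) is treated as the oldest layout.
    """
    migrated = dict(state)  # never mutate the caller's dict
    if version is None or version < 2:
        if "scale" in migrated:
            migrated["px_scale"] = migrated.pop("scale")
    return migrated


print(migrate_state_dict({"scale": 0.5}, version=1))  # {'px_scale': 0.5}
```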

training

VAEC.training: bool

Whether the module is in training mode (True) or evaluation mode (False); set by train() and eval().

Methods

generative

VAEC.generative(z, library, y)

Runs the generative model.
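The generative step can be sketched for one cell, assuming (as is common for scvi-tools models, but not stated on this page) a negative binomial (NB) likelihood whose per-gene mean is the decoder output scaled by the cell's library size. decode_rate and nb_log_prob are hypothetical helpers; theta is the inverse dispersion, and library is taken here as a raw total count rather than its logarithm.

```python
import math
from typing import List


def decode_rate(px_scale: List[float], library: float) -> List[float]:
    """Hypothetical: scale the decoder's per-gene output by the library size."""
    return [library * s for s in px_scale]


def nb_log_prob(x: int, mu: float, theta: float, eps: float = 1e-8) -> float:
    """Log-pmf of a negative binomial with mean mu and inverse dispersion theta."""
    # eps guards log(0) when mu is zero.
    log_theta_mu = math.log(theta + mu + eps)
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * (math.log(theta + eps) - log_theta_mu)
        + x * (math.log(mu + eps) - log_theta_mu)
    )


rates = decode_rate([0.1, 0.2], library=100.0)  # per-gene NB means, about [10, 20]
observed = [8, 25]
log_lik = sum(nb_log_prob(x, mu, theta=2.0) for x, mu in zip(observed, rates))
```

With this parameterization the variance is mu + mu**2 / theta, so smaller theta means more overdispersion than a Poisson model.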

inference

VAEC.inference(x, y, n_samples=1)

High-level inference method.

Runs the inference (encoder) model.
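Assuming a diagonal Gaussian posterior q(z | x, y), as in a standard VAE, the latent draws produced during inference can be sketched with the reparameterization trick. sample_latent is a hypothetical helper; qz_m and qz_v stand for the posterior means and variances an encoder would output, and n_samples mirrors the argument of VAEC.inference.

```python
import math
import random
from typing import List, Optional


def sample_latent(
    qz_m: List[float],
    qz_v: List[float],
    n_samples: int = 1,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Draw n_samples latent vectors z = mean + sqrt(var) * eps, eps ~ N(0, 1)."""
    if any(v <= 0.0 for v in qz_v):
        raise ValueError("posterior variances must be positive")
    if rng is None:
        rng = random.Random()
    # The noise is drawn independently of the parameters, so in an autodiff
    # framework gradients would flow through qz_m and qz_v (reparameterization).
    return [
        [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(qz_m, qz_v)]
        for _ in range(n_samples)
    ]


z = sample_latent([0.5, -1.0], [0.04, 0.25], n_samples=3, rng=random.Random(0))
```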

loss

VAEC.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)

Compute the loss for a minibatch of data.

This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.

This function should return an object of type LossRecorder.
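A sketch of how such a loss is typically assembled, assuming a standard normal prior on z and assuming that ct_weight (present in the constructor signature but not documented above) acts as a per-label multiplicative weight on each cell's loss. The reconstruction term is taken as given (for example, a per-cell negative NB log-likelihood), the helper names are hypothetical, and LossRecorder is not reproduced.

```python
import math
from typing import List, Sequence


def kl_normal_std_normal(qz_m: Sequence[float], qz_v: Sequence[float]) -> float:
    """Closed-form KL(N(m, diag(v)) || N(0, I)), summed over latent dimensions."""
    return sum(0.5 * (v + m * m - 1.0 - math.log(v)) for m, v in zip(qz_m, qz_v))


def weighted_loss(
    reconst_loss: List[float],
    kl_z: List[float],
    labels: List[int],
    ct_weight: List[float],
    kl_weight: float = 1.0,
) -> float:
    """Mean over cells of ct_weight[label] * (reconst + kl_weight * KL)."""
    # Assumption: ct_weight holds one multiplicative weight per label.
    terms = [
        ct_weight[y] * (r + kl_weight * k)
        for r, k, y in zip(reconst_loss, kl_z, labels)
    ]
    return sum(terms) / len(terms)


kl_z = [kl_normal_std_normal([0.0, 0.0], [1.0, 1.0]), kl_normal_std_normal([1.0], [1.0])]
loss = weighted_loss([10.0, 20.0], kl_z, labels=[0, 1], ct_weight=[1.0, 2.0], kl_weight=0.5)
print(kl_z, loss)  # [0.0, 0.5] 25.25
```

kl_weight scales only the KL term, which is the usual way to anneal the KL penalty during training.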

sample

VAEC.sample(tensors, n_samples=1)

Generate observation samples from the posterior predictive distribution.

The posterior predictive distribution is written as $p(\hat{x} \mid x)$.

Parameters
tensors

Dictionary of input tensors

n_samples

Number of samples to draw for each cell

Return type

ndarray

Returns

x_new : torch.Tensor

Tensor with shape (n_cells, n_genes, n_samples)