scvi.module.VAEC#
- class scvi.module.VAEC(n_input, n_labels=0, n_hidden=128, n_latent=5, n_layers=2, log_variational=True, ct_weight=None, dropout_rate=0.05, **module_kwargs)[source]#
Bases: BaseModuleClass
Conditional variational auto-encoder model.
This is an implementation of the CondSCVI model.
- Parameters:
  - n_input : int
    Number of input genes
  - n_labels : int (default: 0)
    Number of labels
  - n_hidden : int (default: 128)
    Number of nodes per hidden layer
  - n_latent : int (default: 5)
    Dimensionality of the latent space
  - n_layers : int (default: 2)
    Number of hidden layers used for encoder and decoder NNs
  - dropout_rate : float (default: 0.05)
    Dropout rate for the encoder and decoder neural networks
  - log_variational : bool (default: True)
    Log(data + 1) prior to encoding for numerical stability. Not normalization.
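A minimal usage sketch (not part of the original docs): constructing the module for a hypothetical dataset with 2,000 genes and 10 cell-type labels; the remaining arguments keep the defaults listed above.

```python
from scvi.module import VAEC

# Hypothetical dimensions: 2,000 genes and 10 cell-type labels.
module = VAEC(
    n_input=2000,  # number of input genes
    n_labels=10,   # number of cell-type labels
    n_latent=5,    # dimensionality of the latent space (default)
    n_layers=2,    # hidden layers in encoder and decoder (default)
)
```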
Attributes table#
- T_destination
- device
- dump_patches
- training

Methods table#
- generative: Runs the generative model.
- inference: High level inference method.
- loss: Compute the loss for a minibatch of data.
- sample: Generate observation samples from the posterior predictive distribution.
Attributes#
T_destination#
- VAEC.T_destination#
alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
device#
- VAEC.device#
dump_patches#
- VAEC.dump_patches: bool = False#
This allows better BC support for load_state_dict(). In state_dict(), the version number will be saved as in the attribute _metadata of the returned state dict, and thus pickled. _metadata is a dictionary with keys that follow the naming convention of state dict. See _load_from_state_dict on how to use this information in loading.
If new parameters/buffers are added/removed from a module, this number shall be bumped, and the module's _load_from_state_dict method can compare the version number and do appropriate changes if the state dict is from before the change.
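A short sketch of where this version metadata actually lives, assuming the `module` constructed in the earlier example; `_metadata` and its `version` key are standard PyTorch state-dict internals.

```python
import torch

# state_dict() returns an OrderedDict carrying a _metadata attribute
# keyed by submodule path; "" is the root module.
sd = module.state_dict()
print(sd._metadata[""]["version"])  # bumped when a module's layout changes

# Round trip: _metadata is pickled along with the weights.
torch.save(sd, "vaec_state.pt")
module.load_state_dict(torch.load("vaec_state.pt"))
```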
training#
- VAEC.training: bool#
Methods#
generative#
Runs the generative model.
inference#
High level inference method.
loss#
- VAEC.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)[source]#
Compute the loss for a minibatch of data.
This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.
This function should return an object of type LossRecorder.
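A hedged sketch of how these pieces fit together in a training step. It assumes the `module` from the earlier example and a `tensors` minibatch dict produced by scvi-tools' data loaders; on BaseModuleClass, `forward()` chains inference, generative, and loss.

```python
# `tensors` is a minibatch dict from an scvi-tools data loader (assumed here).
inference_outputs, generative_outputs, losses = module.forward(tensors)

# Equivalent manual call, down-weighting the KL term during warm-up:
losses = module.loss(
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight=0.5,  # anneal the KL divergence early in training
)
print(losses.loss)  # total objective recorded in the LossRecorder
```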
sample#
- VAEC.sample(tensors, n_samples=1)[source]#
Generate observation samples from the posterior predictive distribution.
The posterior predictive distribution is written as \(p(\hat{x} \mid x)\).
- Parameters:
  - tensors
    Tensors dict
  - n_samples : int (default: 1)
    Number of required samples for each cell
- Return type: Tensor
- Returns:
  x_new : torch.Tensor
  Tensor with shape (n_cells, n_genes, n_samples)
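A hedged usage sketch, again assuming a `tensors` minibatch dict from an scvi-tools data loader.

```python
# Draw 25 posterior predictive samples per cell.
x_new = module.sample(tensors, n_samples=25)
print(x_new.shape)  # (n_cells, n_genes, 25)
```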