class scvi.external.cellassign.CellAssignModule(n_genes, rho, basis_means, b_g_0=None, random_b_g_0=True, n_batch=0, n_cats_per_cov=None, n_continuous_cov=0)

Bases: scvi.module.base._base_module.BaseModuleClass

Model for CellAssign.

n_genes : int

Number of input genes


Number of input cell types

rho : Tensor

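Binary matrix of cell-type markers, with one row per gene and one column per cell type (1 if the gene is a marker for that type, 0 otherwise).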

basis_means : Tensor

Centers of the radial basis functions used to model the negative binomial dispersion as a function of expected expression.

b_g_0 : Optional[Tensor] (default: None)

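Base gene expression tensor. If None, b_g_0 is randomly initialized.

random_b_g_0 : bool (default: True)

Override that forces a random initialization of b_g_0. If True, b_g_0 is randomly initialized even when a tensor is passed; if False, the supplied b_g_0 is used (a random initialization is still used when b_g_0 is None).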

n_batch : int (default: 0)

Number of batches. If 0, no batch correction is performed.

n_cats_per_cov : Optional[Iterable[int]] (default: None)

Number of categories for each extra categorical covariate

n_continuous_cov : int (default: 0)

Number of continuous covariates
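The constructor inputs can be prepared from a marker list along the following lines. This is a plain-Python sketch, not scvi-tools code: it assumes rho has one row per gene and one column per cell type, assumes basis_means are evenly spaced centers between the smallest and largest observed count (10 centers is an illustrative choice), and mirrors the b_g_0 / random_b_g_0 rule described above. The helper names (build_rho, basis_means_linspace, initial_b_g_0) and the marker genes are hypothetical; in practice the resulting lists would be converted to torch tensors before being passed to CellAssignModule.

```python
import random

# Hypothetical helpers (not scvi-tools API) for preparing CellAssignModule inputs.


def build_rho(markers, genes, cell_types):
    """Binary gene-by-cell-type matrix: 1 if the gene is a marker for that type."""
    return [[1 if gene in markers.get(ct, ()) else 0 for ct in cell_types] for gene in genes]


def basis_means_linspace(x_min, x_max, n_basis=10):
    """Evenly spaced basis centers from x_min to x_max (assumed spacing rule)."""
    step = (x_max - x_min) / (n_basis - 1)
    return [x_min + i * step for i in range(n_basis)]


def initial_b_g_0(n_genes, b_g_0=None, random_b_g_0=True, seed=0):
    """Documented rule: random init if b_g_0 is None or random_b_g_0 is True."""
    if b_g_0 is None or random_b_g_0:
        rng = random.Random(seed)
        return [rng.gauss(0.0, 1.0) for _ in range(n_genes)]
    return list(b_g_0)


# Illustrative marker genes for three cell types.
genes = ["CD3E", "CD19", "MS4A1", "LYZ"]
cell_types = ["T cell", "B cell", "Monocyte"]
markers = {"T cell": {"CD3E"}, "B cell": {"CD19", "MS4A1"}, "Monocyte": {"LYZ"}}

rho = build_rho(markers, genes, cell_types)
# rho == [[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]]
basis = basis_means_linspace(0.0, 9.0)
# basis == [0.0, 1.0, ..., 9.0]
b_g_0 = initial_b_g_0(len(genes))
```

Building rho by lookup keeps the gene order of the expression matrix authoritative, which is what makes row g of rho line up with gene g of the counts.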

Attributes table

Methods table

generative(x, size_factor[, design_matrix])

Run the generative model.


Run the inference (recognition) model.

loss(tensors, inference_outputs, ...[, n_obs])

Compute the loss for a minibatch of data.

sample(tensors[, n_samples, library_size])

Generate samples from the learned model.




CellAssignModule.T_destination

alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])



CellAssignModule.dump_patches: bool = False

This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number is saved in the _metadata attribute of the returned state dict, and is therefore pickled. _metadata is a dictionary whose keys follow the state-dict naming convention. See _load_from_state_dict for how to use this information when loading.

If parameters or buffers are added to or removed from a module, this number should be incremented, and the module's _load_from_state_dict method can compare the version number and make the appropriate changes if the state dict predates the change.
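The version-bump mechanism can be illustrated without PyTorch. The sketch below is a hypothetical, framework-free analogue of what a module's _load_from_state_dict might do: it reads the saved version from the metadata and migrates an older layout before use. CURRENT_VERSION, the parameter names, and the rename from b0 to b_g_0 are invented for illustration and are not part of PyTorch or scvi-tools.

```python
# Hypothetical, framework-free analogue of versioned state-dict loading.
CURRENT_VERSION = 2


def save_state(params):
    """Return a copy of the parameters plus metadata recording the version."""
    return dict(params), {"version": CURRENT_VERSION}


def load_state(state, metadata):
    """Migrate a state dict saved by an older version, then return it."""
    version = metadata.get("version", 1)  # treat a missing version as the oldest
    state = dict(state)
    if version < 2:
        # Invented change in version 2: parameter "b0" was renamed to "b_g_0".
        state["b_g_0"] = state.pop("b0")
    return state


old = load_state({"b0": [0.1, 0.2]}, {"version": 1})
current = load_state(*save_state({"b_g_0": [0.3]}))
```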

CellAssignModule.training: bool



CellAssignModule.generative(x, size_factor, design_matrix=None)

Run the generative model.

This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x|z)\).

This function should return a dictionary with str keys and Tensor values.
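For CellAssign, the likelihood parameters are negative binomial means indexed by cell, gene, and cell type. The sketch below assumes a CellAssign-style mean, log mu_gc = log s + b_g_0[g] + delta[g][c] * rho[g][c], for a single cell without covariates, and turns the resulting log-likelihoods into cell-type probabilities (gamma). It is plain Python rather than the module's tensor code: the fixed inverse dispersion theta = 10 is an assumption standing in for the basis-function dispersion model, and the function names and the "p_x_c" / "gamma" keys are illustrative.

```python
import math


def nb_log_pmf(x, mu, theta):
    """Negative binomial log-pmf with mean mu and inverse dispersion theta."""
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cell_type_posterior(x, size_factor, rho, delta, b_g_0, log_pi, theta=10.0):
    """Single-cell sketch: log mu_gc = log s + b_g_0[g] + delta[g][c] * rho[g][c].

    theta is a fixed inverse dispersion (an assumption replacing the
    basis-function dispersion model); covariates are ignored.
    """
    n_genes, n_types = len(rho), len(rho[0])
    p_x_c = []
    for c in range(n_types):
        log_lik = 0.0
        for g in range(n_genes):
            log_mu = math.log(size_factor) + b_g_0[g] + delta[g][c] * rho[g][c]
            log_lik += nb_log_pmf(x[g], math.exp(log_mu), theta)
        p_x_c.append(log_lik + log_pi[c])  # log p(x, c), genes treated as independent
    norm = logsumexp(p_x_c)
    gamma = [math.exp(v - norm) for v in p_x_c]  # posterior probability of each type
    return {"p_x_c": p_x_c, "gamma": gamma}


# Two genes, two cell types; gene g is the marker for type g.
out = cell_type_posterior(
    x=[20, 1],
    size_factor=1.0,
    rho=[[1, 0], [0, 1]],
    delta=[[2.0, 2.0], [2.0, 2.0]],
    b_g_0=[0.0, 0.0],
    log_pi=[math.log(0.5), math.log(0.5)],
)
```

Because the cell has many counts for the type-0 marker and few for the type-1 marker, nearly all of the posterior mass in out["gamma"] lands on type 0.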



Run the inference (recognition) model.

In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.

This function should return a dictionary with str keys and Tensor values.
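The dictionary-passing contract between these methods can be sketched with a toy class. ToyModule, its "mu" key, and its squared-error loss are hypothetical and are not the scvi-tools BaseModuleClass API; the point is only that inference and generative each return a dict keyed by strings, and loss consumes both dicts. Plain floats stand in for tensors.

```python
# Hypothetical toy module; it only illustrates the dict-passing pattern.
class ToyModule:
    def inference(self, x):
        """No encoder in this toy, so there are no variational parameters to return."""
        return {}

    def generative(self, x, size_factor):
        """Return likelihood parameters: here, one predicted mean per feature."""
        return {"mu": [size_factor for _ in x]}

    def loss(self, x, inference_outputs, generative_outputs):
        """Consume both dicts; a squared error stands in for a real likelihood."""
        mu = generative_outputs["mu"]
        return sum((xi - mi) ** 2 for xi, mi in zip(x, mu)) / len(x)

    def forward(self, x, size_factor):
        inference_outputs = self.inference(x)
        generative_outputs = self.generative(x, size_factor)
        return inference_outputs, generative_outputs, self.loss(
            x, inference_outputs, generative_outputs
        )


inf_out, gen_out, value = ToyModule().forward([1.0, 3.0], size_factor=2.0)
```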


CellAssignModule.loss(tensors, inference_outputs, generative_outputs, n_obs=1.0)

Compute the loss for a minibatch of data.

This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.

This function should return an object of type LossRecorder.
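The n_obs argument lets a loss computed on a minibatch be put on the scale of the full dataset. The sketch below assumes an EM-style objective: a per-cell expected negative log joint weighted by the responsibilities gamma, averaged over the minibatch, scaled by n_obs, combined with a global prior term that is counted once, and finally divided by n_obs. The function name and the prior_neg_log_prob input are hypothetical, and the exact terms used by CellAssignModule.loss should be checked against its source.

```python
def minibatch_loss(gamma, p_x_c, prior_neg_log_prob, n_obs):
    """Hypothetical EM-style loss for one minibatch.

    gamma[i][c]: responsibility of cell type c for cell i.
    p_x_c[i][c]: log joint probability of cell i's counts and type c.
    prior_neg_log_prob: negative log prior of the global parameters.
    n_obs: number of cells in the full training set.
    """
    # Expected negative log joint per cell, weighted by the responsibilities.
    q_per_cell = [sum(g * -p for g, p in zip(gam, pxc)) for gam, pxc in zip(gamma, p_x_c)]
    mean_q = sum(q_per_cell) / len(q_per_cell)
    # Rescale the minibatch mean to the full dataset, add the global prior once,
    # and divide by n_obs to report a per-observation value.
    return (mean_q * n_obs + prior_neg_log_prob) / n_obs


value = minibatch_loss(
    gamma=[[1.0, 0.0], [0.5, 0.5]],
    p_x_c=[[-2.0, -5.0], [-4.0, -4.0]],
    prior_neg_log_prob=50.0,
    n_obs=100,
)
```

With this form the global prior contributes prior_neg_log_prob / n_obs per observation, so its weight relative to the data term does not depend on the minibatch size.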


CellAssignModule.sample(tensors, n_samples=1, library_size=1)

Generate samples from the learned model.