scvi.external.cellassign.CellAssignModule
- class scvi.external.cellassign.CellAssignModule(n_genes, rho, basis_means, b_g_0=None, random_b_g_0=True, n_batch=0, n_cats_per_cov=None, n_continuous_cov=0)
Bases: scvi.module.base._base_module.BaseModuleClass
Model for CellAssign.
- Parameters
- n_genes : int
  Number of input genes.
- n_labels
  Number of input cell types. This is not a constructor argument; it corresponds to the number of columns of rho.
- rho : Tensor
  Binary matrix of cell type markers (genes × cell types).
- basis_means : Tensor
  Basis means, passed as a Tensor.
- b_g_0 : Tensor | None (default: None)
  Base gene expression tensor. If None, a randomly initialized b_g_0 is used.
- random_b_g_0 : bool (default: True)
  Override that forces a random initialization of b_g_0. If True, b_g_0 is randomly initialized even when a b_g_0 tensor is passed; if False, the passed b_g_0 is used.
- n_batch : int (default: 0)
  Number of batches. If 0, no batch correction is performed.
- n_cats_per_cov : Iterable[int] | None (default: None)
  Number of categories for each extra categorical covariate.
- n_continuous_cov : int (default: 0)
  Number of continuous covariates.
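The constructor expects the marker matrix rho (genes × cell types), basis_means, and optionally b_g_0 as tensors. The sketch below shows one way to prepare them with NumPy. The helper names, the marker genes, the choice of 10 evenly spaced basis means over the count range, and the standardized-mean initialization of b_g_0 are all assumptions for illustration, not the library's own preprocessing.

```python
import numpy as np

def build_rho(genes, cell_types, markers):
    """Hypothetical helper: binary (n_genes, n_labels) matrix, rho[g, c] = 1 if gene g marks type c."""
    gene_index = {gene: i for i, gene in enumerate(genes)}
    rho = np.zeros((len(genes), len(cell_types)), dtype=np.float32)
    for c, cell_type in enumerate(cell_types):
        for gene in markers.get(cell_type, []):
            rho[gene_index[gene], c] = 1.0
    return rho

def build_basis_means(x, n_basis=10):
    """Hypothetical helper: evenly spaced basis means spanning the observed count range (assumed choice)."""
    return np.linspace(x.min(), x.max(), n_basis, dtype=np.float32)

def build_b_g_0(x):
    """Hypothetical helper: standardized per-gene mean expression as an initial b_g_0 (assumed choice)."""
    col_means = x.mean(axis=0)
    return ((col_means - col_means.mean()) / col_means.std()).astype(np.float32)

# Toy inputs: hypothetical marker genes and cell types.
genes = ["CD3E", "CD4", "CD8A", "MS4A1"]
cell_types = ["CD4 T", "CD8 T", "B"]
markers = {"CD4 T": ["CD3E", "CD4"], "CD8 T": ["CD3E", "CD8A"], "B": ["MS4A1"]}

rng = np.random.default_rng(0)
x = rng.poisson(2.0, size=(50, len(genes))).astype(np.float32)  # cells x genes counts

rho = build_rho(genes, cell_types, markers)
basis_means = build_basis_means(x)
b_g_0 = build_b_g_0(x)
print(rho.shape, basis_means.shape, b_g_0.shape)  # (4, 3) (10,) (4,)
```

To build the module, these arrays could then be converted with torch.as_tensor and passed as rho, basis_means and b_g_0, with n_genes=rho.shape[0].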
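Methods table
| generative(x, size_factor[, design_matrix]) | Run the generative model. |
| inference() | Run the inference (recognition) model. |
| loss(tensors, inference_outputs, generative_outputs[, n_obs]) | Compute the loss for a minibatch of data. |
| | Generate samples from the learned model. |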
Attributes
T_destination
- CellAssignModule.T_destination
alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
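In PyTorch, this TypeVar annotates the destination argument of state_dict(), so the method returns the same mapping type the caller passes in. The sketch below illustrates that pattern with plain Python typing. The function save_params and the Tensor stand-in are hypothetical, and the bound is loosened to MutableMapping only because the sketch writes into the destination.

```python
from collections import OrderedDict
from typing import Dict, List, MutableMapping, TypeVar

Tensor = List[float]  # hypothetical stand-in for torch.Tensor, so the sketch needs no PyTorch

# Bound to MutableMapping (not Mapping) because this sketch writes into the destination.
T_dest = TypeVar("T_dest", bound=MutableMapping[str, Tensor])

def save_params(params: Dict[str, Tensor], destination: T_dest, prefix: str = "") -> T_dest:
    """Copy params into destination under prefixed keys and return that same object."""
    for name, value in params.items():
        destination[prefix + name] = value
    return destination

# The declared return type follows the argument: pass an OrderedDict, get an OrderedDict back.
od = save_params({"b_g_0": [0.1, -0.2], "log_a": [0.0]}, OrderedDict(), prefix="module.")
print(type(od).__name__, list(od))  # OrderedDict ['module.b_g_0', 'module.log_a']
```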
device
- CellAssignModule.device
dump_patches
- CellAssignModule.dump_patches: bool = False
This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number is saved in the attribute _metadata of the returned state dict, and is thus pickled. _metadata is a dictionary whose keys follow the naming convention of the state dict. See _load_from_state_dict for how to use this information when loading. If parameters or buffers are added to or removed from a module, this number should be bumped, and the module's _load_from_state_dict method can compare the version number and make the appropriate changes if the state dict predates the change.
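The version-bump mechanism described above can be illustrated without PyTorch. In this minimal sketch, the class VersionedModule, its parameter names, and the key rename between versions 1 and 2 are all hypothetical; only the idea of storing the version in a _metadata attribute of the state dict and migrating older checkpoints on load mirrors the text.

```python
from collections import OrderedDict
from typing import Any, Dict

class VersionedModule:
    """Hypothetical module: version 1 checkpoints stored 'b_g0', version 2 renamed it 'b_g_0'."""

    _version = 2  # bump whenever the saved parameter layout changes

    def __init__(self) -> None:
        self.params: Dict[str, Any] = {"b_g_0": None}

    def state_dict(self) -> "OrderedDict[str, Any]":
        state = OrderedDict(self.params)
        # Keep the version in a _metadata attribute (not a key), keyed by module prefix ("" = root).
        state._metadata = {"": {"version": self._version}}
        return state

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Checkpoints saved without metadata are treated as predating versioning.
        version = getattr(state, "_metadata", {}).get("", {}).get("version")
        state = dict(state)
        if (version is None or version < 2) and "b_g0" in state:
            state["b_g_0"] = state.pop("b_g0")  # migrate the old key name
        self.params.update(state)

old_checkpoint = {"b_g0": [0.5, 0.1]}  # plain dict: no _metadata attribute
module = VersionedModule()
module.load_state_dict(old_checkpoint)
print(module.params)  # {'b_g_0': [0.5, 0.1]}
```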
training
Methods
generative
- CellAssignModule.generative(x, size_factor, design_matrix=None)
Run the generative model.
This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x|z)\).
This function should return a dictionary with str keys and Tensor values.
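For CellAssign, as we understand the published model, the likelihood is a negative binomial whose mean for cell \(n\), gene \(g\) and candidate cell type \(c\) is \(\mu_{ngc} = s_n \exp(b_{g0} + \delta_{gc}\rho_{gc})\), with \(s_n\) the size factor and \(\delta_{gc} > 0\) the marker over-expression; the dispersion is a weighted sum of radial basis functions centered at basis_means. The NumPy sketch below computes these parameters. The function name, the argument names delta and log_a, the bandwidth rule, and the small floor on the dispersion are assumptions, and covariates from design_matrix are omitted.

```python
import numpy as np

def cellassign_generative_sketch(size_factor, b_g_0, delta, rho, log_a, basis_means):
    """Hypothetical sketch of CellAssign-style likelihood parameters.

    size_factor: (n,) positive per-cell size factors
    b_g_0:       (g,) base log-expression per gene
    delta:       (g, c) positive marker over-expression
    rho:         (g, c) binary marker matrix
    log_a:       (B,) log weights of the radial basis functions
    basis_means: (B,) centers of the radial basis functions
    """
    # log mu_ngc = log s_n + b_g0 + delta_gc * rho_gc, broadcast to (n, g, c)
    log_mu = (np.log(size_factor)[:, None, None]
              + b_g_0[None, :, None]
              + (delta * rho)[None, :, :])
    mu = np.exp(log_mu)
    # One shared bandwidth derived from the spacing of the basis means (assumption).
    b = 1.0 / (2.0 * (basis_means[1] - basis_means[0]) ** 2)
    a = np.exp(log_a)
    # phi_ngc = sum_b a_b * exp(-b * (mu_ngc - m_b)^2), plus a small floor
    sq_dist = (mu[..., None] - basis_means) ** 2  # (n, g, c, B)
    phi = np.sum(a * np.exp(-b * sq_dist), axis=-1) + 1e-4
    return {"mu": mu, "phi": phi}

rng = np.random.default_rng(0)
n_cells, n_genes, n_labels, n_basis = 5, 4, 3, 10
rho = np.array([[1, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)
out = cellassign_generative_sketch(
    size_factor=rng.uniform(0.5, 2.0, size=n_cells),
    b_g_0=rng.normal(size=n_genes),
    delta=np.exp(rng.normal(size=(n_genes, n_labels))),
    rho=rho,
    log_a=np.zeros(n_basis),
    basis_means=np.linspace(0.0, 10.0, n_basis),
)
print(out["mu"].shape, out["phi"].shape)  # (5, 4, 3) (5, 4, 3)
```

Because \(\rho_{gc} = 0\) switches off \(\delta_{gc}\), a gene's mean is identical across all cell types it does not mark, and higher only for the types it does.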
inference
- CellAssignModule.inference()
Run the inference (recognition) model.
In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.
This function should return a dictionary with str keys and Tensor values.
loss
- CellAssignModule.loss(tensors, inference_outputs, generative_outputs, n_obs=1.0)
Compute the loss for a minibatch of data.
This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.
This function should return an object of type LossRecorder.
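A minimal EM-style sketch of such a loss follows, assuming the per-cell, per-type log-likelihoods \(\log p(y_n \mid c)\) (summed over genes) have already been computed from the generative outputs. The names theta_logit and prior_penalty are hypothetical, and n_obs is assumed to be the number of cells in the full dataset, so the prior penalty is counted once per dataset while the data term is a minibatch average. It returns plain NumPy values rather than a LossRecorder.

```python
import numpy as np

def log_softmax(z, axis=-1):
    """Numerically stable log-softmax along the given axis."""
    z_max = np.max(z, axis=axis, keepdims=True)
    return z - z_max - np.log(np.sum(np.exp(z - z_max), axis=axis, keepdims=True))

def cellassign_loss_sketch(log_lik_nc, theta_logit, prior_penalty, n_obs):
    """Hypothetical EM-style minibatch loss.

    log_lik_nc:    (n, c) log p(y_n | cell type c), summed over genes
    theta_logit:   (c,) unnormalized log cell-type proportions
    prior_penalty: scalar negative log prior of global parameters
    n_obs:         number of cells in the full dataset (assumption)
    """
    p_y_c = log_lik_nc + log_softmax(theta_logit)  # log p(y_n, c)
    gamma = np.exp(log_softmax(p_y_c, axis=1))     # responsibilities; rows sum to 1
    q_per_cell = np.sum(gamma * -p_y_c, axis=1)    # expected negative log joint per cell
    # Data term averaged over the minibatch; prior counted once for the whole dataset.
    loss = (np.mean(q_per_cell) * n_obs + prior_penalty) / n_obs
    return float(loss), gamma

log_lik = np.array([[-10.0, -12.0], [-15.0, -11.0], [-9.0, -9.0]])
loss, gamma = cellassign_loss_sketch(log_lik, np.zeros(2), prior_penalty=3.0, n_obs=1000.0)
print(gamma.argmax(axis=1))  # [0 1 0]
```

The responsibilities gamma are the per-cell cell-type probabilities; with equal proportions, the third cell's tie gives it probability 0.5 for each type.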