scvi.module.base.SupervisedModuleClass

class scvi.module.base.SupervisedModuleClass

Bases: object

General-purpose methods for supervised classification and loss calculation.

Methods table

classification_loss(labelled_dataset)

classify(x[, batch_index, cont_covs, ...])

Forward pass through the encoder and classifier.

on_load(model, **kwargs)

Methods

SupervisedModuleClass.classification_loss(labelled_dataset)
Return type:

tuple[Tensor, Tensor, Tensor]

SupervisedModuleClass.classify(x, batch_index=None, cont_covs=None, cat_covs=None, use_posterior_mean=True)

Forward pass through the encoder and classifier.

Parameters:
  • x (Tensor) – Tensor of shape (n_obs, n_vars).

  • batch_index (Tensor | None (default: None)) – Tensor of shape (n_obs,) denoting batch indices.

  • cont_covs (Tensor | None (default: None)) – Tensor of shape (n_obs, n_continuous_covariates).

  • cat_covs (Tensor | None (default: None)) – Tensor of shape (n_obs, n_categorical_covariates).

  • use_posterior_mean (bool (default: True)) – Whether to use the posterior mean of the latent distribution for classification.

Return type:

Tensor

Returns:

Tensor of shape (n_obs, n_labels) containing the logit score for each label. Before v1.1, this method returned per-label probabilities by default; see #2301 for more details.
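The snippet below is a toy, pure-Python sketch of the contract described above, not the scvi-tools implementation. It assumes the encoder output is already available as per-observation means and standard deviations of a diagonal Gaussian latent posterior, and it uses a hypothetical bias-free linear classifier; `toy_classify` and `softmax_rows` are illustrative names, not library functions. With `use_posterior_mean=True` the latent mean is passed to the classifier, so the logits are deterministic; otherwise one reparameterized sample is drawn. Because the result is logits, a softmax over the label axis converts them into per-label probabilities.

```python
import math
import random
from typing import List, Optional

Matrix = List[List[float]]  # row-major, one row per observation


def toy_classify(
    qz_mean: Matrix,
    qz_std: Matrix,
    weights: Matrix,
    use_posterior_mean: bool = True,
    rng: Optional[random.Random] = None,
) -> Matrix:
    """Hypothetical stand-in for classify: returns logits of shape (n_obs, n_labels).

    qz_mean, qz_std: (n_obs, n_latent) parameters of a diagonal Gaussian posterior.
    weights: (n_labels, n_latent) weights of a hypothetical linear classifier.
    """
    if rng is None:
        rng = random.Random(0)
    logits = []
    for mean_row, std_row in zip(qz_mean, qz_std):
        if use_posterior_mean:
            z = list(mean_row)  # deterministic: use the posterior mean
        else:
            # stochastic: one reparameterized draw, z = mean + std * eps
            z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean_row, std_row)]
        # one logit per label: dot product of that label's weights with z
        logits.append([sum(w * zi for w, zi in zip(w_row, z)) for w_row in weights])
    return logits


def softmax_rows(logits: Matrix) -> Matrix:
    """Convert per-label logits into per-label probabilities, row by row."""
    probs = []
    for row in logits:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


qz_mean = [[1.0, 0.0], [0.0, 2.0]]  # 2 observations, 2 latent dimensions
qz_std = [[0.5, 0.5], [0.5, 0.5]]
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 labels

logits = toy_classify(qz_mean, qz_std, weights)
print(logits)  # [[1.0, 0.0, 1.0], [0.0, 2.0, 2.0]]
probs = softmax_rows(logits)  # each row sums to 1
```

With real tensors, `torch.softmax(logits, dim=-1)` performs the same conversion. The predicted label (the argmax of each row) is the same whether it is taken over logits or probabilities, because softmax preserves the ordering within a row.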

SupervisedModuleClass.on_load(model, **kwargs)