scvi.external.diagvi.DIAGVAE

class scvi.external.diagvi.DIAGVAE(n_inputs, n_batches, n_labels, modality_likelihoods, normalize_lib, guidance_graph, use_gmm_prior, semi_supervised, n_mixture_components, n_latent, n_hidden, n_layers, dropout_rate)

Bases: BaseModuleClass

Variational autoencoder module for DIAGVI multi-modal integration.

Supports GMM priors, semi-supervised classification, and flexible modality-specific decoders.

Parameters:
  • n_inputs (dict[str, int]) – Number of input features for each modality.

  • n_batches (dict[str, int]) – Number of batches for each modality.

  • n_labels (dict[str, int]) – Number of labels/classes for each modality.

  • modality_likelihoods (dict[str, str]) – Likelihood model for each modality. One of: ‘nb’, ‘zinb’, ‘normal’, ‘log1pnormal’, ‘zig’, ‘nbmixture’.

  • normalize_lib (dict[str, bool]) – Whether to normalize counts with library size in the model for each modality.

  • guidance_graph (Data) – Graph object encoding feature correspondences.

  • use_gmm_prior (dict[str, bool]) – Whether to use a GMM prior for each modality.

  • semi_supervised (dict[str, bool]) – Whether to use semi-supervised classification for each modality.

  • n_mixture_components (dict[str, int]) – Number of mixture components for the GMM prior for each modality. If semi_supervised is True for a modality, the value given here is ignored and the number of components is set to the number of unique labels in labels_key.

  • n_latent (int) – Dimensionality of the latent space.

  • n_hidden (int) – Number of nodes per hidden layer.

  • n_layers (int) – Number of hidden layers.

  • dropout_rate (float) – Dropout rate for encoders.
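
Most constructor arguments are dictionaries keyed by modality name. The following is a minimal sketch of assembling and sanity-checking those dictionaries before the module is constructed. The modality names "rna" and "protein", the sizes, and the helper check_modality_config are illustrative assumptions and not part of scvi-tools; the sketch also assumes that every per-modality dictionary is expected to use the same set of keys.

```python
# Likelihood names listed for `modality_likelihoods` in the parameter list above.
VALID_LIKELIHOODS = {"nb", "zinb", "normal", "log1pnormal", "zig", "nbmixture"}


def check_modality_config(**per_modality_args):
    """Hypothetical helper: check that all per-modality dicts share the same keys."""
    names = None
    for arg_name, mapping in per_modality_args.items():
        keys = set(mapping)
        if names is None:
            names = keys
        elif keys != names:
            raise ValueError(
                f"{arg_name} has modalities {sorted(keys)}, expected {sorted(names)}"
            )
    # Reject likelihood names that are not in the documented list.
    likelihoods = per_modality_args.get("modality_likelihoods", {})
    bad = {m: name for m, name in likelihoods.items() if name not in VALID_LIKELIHOODS}
    if bad:
        raise ValueError(f"unknown likelihoods: {bad}")
    return sorted(names or [])


# Illustrative values only; real values come from the registered data.
config = {
    "n_inputs": {"rna": 2000, "protein": 40},
    "n_batches": {"rna": 2, "protein": 1},
    "n_labels": {"rna": 8, "protein": 8},
    "modality_likelihoods": {"rna": "nb", "protein": "normal"},
    "normalize_lib": {"rna": True, "protein": False},
    "use_gmm_prior": {"rna": True, "protein": True},
    "semi_supervised": {"rna": False, "protein": False},
    "n_mixture_components": {"rna": 10, "protein": 10},
}

modalities = check_modality_config(**config)
print(modalities)  # ['protein', 'rna']
```

The checked dictionaries would then be passed to DIAGVAE together with guidance_graph and the scalar arguments (n_latent, n_hidden, n_layers, dropout_rate).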

Attributes table

Methods table

encode_graph(device)

Run the graph encoder, moving the guidance graph to device on first call.

generative(z, library[, batch_index, y, v, mode])

Run the generative (decoder) model for a given modality.

inference(x[, mode, v_all, mu_all, ...])

Run the inference (encoder and graph) step for a given modality.

loss(tensors, inference_outputs, ...[, mode])

Compute the loss for a batch.

sample(tensors[, n_samples])

Sample from the generative model for each modality.

Attributes

DIAGVAE.training: bool

Methods

DIAGVAE.encode_graph(device)

Run the graph encoder, moving the guidance graph to device on first call.

Parameters:

device (device) – Target device.

Return type:

tuple[Tensor, Tensor, Tensor]

Returns:

Tuple of (v_all, mu_all, logvar_all) graph embeddings.
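
The "move on first call" behaviour can be pictured as a lazy transfer: the graph is placed on the device the first time it is needed and left there afterwards. The sketch below is a toy stand-in written in plain Python. The class LazyGraphEncoder and its attributes are hypothetical, and it only tracks a device name rather than moving real tensors.

```python
class LazyGraphEncoder:
    """Toy stand-in for a module that owns a guidance graph (not scvi-tools code)."""

    def __init__(self, graph):
        self.graph = graph
        self.graph_device = None  # the graph has not been placed on any device yet
        self.n_transfers = 0

    def encode_graph(self, device):
        # Move the graph only if it is not already on the requested device.
        if self.graph_device != device:
            self.graph_device = device
            self.n_transfers += 1
        # A real encoder would return (v_all, mu_all, logvar_all) tensors here.
        return self.graph_device


encoder = LazyGraphEncoder(graph={"edges": [(0, 1)]})
for _ in range(3):
    encoder.encode_graph("cpu")
print(encoder.n_transfers)  # 1
```

Repeated calls with the same device therefore pay the transfer cost only once.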

DIAGVAE.generative(z, library, batch_index=None, y=None, v=None, mode=None)

Run the generative (decoder) model for a given modality.

Parameters:
  • z (Tensor) – Latent variable tensor.

  • library (Tensor) – Library size tensor.

  • batch_index (Tensor | None (default: None)) – Batch index tensor.

  • y (Tensor | None (default: None)) – Cell type labels (for semi-supervised GMM prior).

  • v (Tensor | None (default: None)) – Feature embedding from the graph encoder.

  • mode (str | None (default: None)) – Name of the modality.

Return type:

dict[str, Tensor]

Returns:

Dictionary containing generative model outputs and distributions.
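
To illustrate why a library size is passed to the decoder, here is a minimal sketch of one common convention for count likelihoods: the decoder produces per-feature proportions, and the library size scales them into an expected count per feature. This is an assumption about the general technique, not a description of how DIAGVAE computes its means. The helper names are hypothetical, and the sketch takes a raw total count rather than a log-transformed one.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_mean(decoder_scores, library_size):
    """Hypothetical helper: turn decoder scores into expected counts per feature."""
    proportions = softmax(decoder_scores)  # non-negative, sums to 1
    return [library_size * p for p in proportions]


# Two features with equal scores split the library size evenly.
print(scaled_mean([0.0, 0.0], 10.0))  # [5.0, 5.0]
```

Under this convention the expected counts of a cell sum to its library size, which is what makes the proportions comparable across cells with different sequencing depth.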

DIAGVAE.inference(x, mode=None, v_all=None, mu_all=None, logvar_all=None, deterministic=False)

Run the inference (encoder and graph) step for a given modality.

Parameters:
  • x (Tensor) – Input data tensor.

  • mode (str | None (default: None)) – Name of the modality.

  • v_all (Tensor | None (default: None)) – Pre-computed graph node embeddings. If None, computed from graph encoder.

  • mu_all (Tensor | None (default: None)) – Pre-computed graph encoder means. Required when v_all is provided.

  • logvar_all (Tensor | None (default: None)) – Pre-computed graph encoder log-variances. Required when v_all is provided.

  • deterministic (bool (default: False)) – Whether to use the mean instead of a sample.

Return type:

dict[str, Tensor]

Returns:

Dictionary of inference outputs, including latent variables and graph embeddings.
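
Because v_all, mu_all and logvar_all can be supplied, the graph encoder can be run once and its outputs reused for every modality. The sketch below shows that call pattern with a stub that mimics only the documented signatures. StubModule, its toy encoder, and the output keys "z" and "v_all" are assumptions for illustration; the real module operates on tensors and returns more entries.

```python
import random


class StubModule:
    """Hypothetical stand-in that mimics the documented call signatures only."""

    def __init__(self):
        self.graph_encoder_calls = 0

    def encode_graph(self, device):
        self.graph_encoder_calls += 1
        mu_all, logvar_all = [0.0, 1.0], [0.0, 0.0]
        return list(mu_all), mu_all, logvar_all  # (v_all, mu_all, logvar_all)

    def inference(self, x, mode=None, v_all=None, mu_all=None,
                  logvar_all=None, deterministic=False):
        if v_all is None:
            # No pre-computed embeddings: fall back to the graph encoder.
            v_all, mu_all, logvar_all = self.encode_graph("cpu")
        elif mu_all is None or logvar_all is None:
            raise ValueError("mu_all and logvar_all are required when v_all is provided")
        # Toy encoder: the mean is x itself with unit variance.
        z = list(x) if deterministic else [xi + random.gauss(0.0, 1.0) for xi in x]
        return {"z": z, "v_all": v_all}


module = StubModule()
v_all, mu_all, logvar_all = module.encode_graph("cpu")  # run the graph encoder once
outputs = {
    mode: module.inference([1.0, 2.0], mode=mode, v_all=v_all, mu_all=mu_all,
                           logvar_all=logvar_all, deterministic=True)
    for mode in ("rna", "protein")  # illustrative modality names
}
print(module.graph_encoder_calls)  # 1
```

With deterministic=True the stub returns the mean unchanged, which mirrors the documented meaning of that flag.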

DIAGVAE.loss(tensors, inference_outputs, generative_outputs, lam_kl, lam_data, mode=None)

Compute the loss for a batch.

Parameters:
  • tensors (dict[str, Tensor]) – Input data tensors.

  • inference_outputs (dict[str, Tensor]) – Outputs from the inference step.

  • generative_outputs (dict[str, Tensor]) – Outputs from the generative step.

  • lam_kl (Tensor | float) – Weight for the KL divergence term.

  • lam_data (Tensor | float) – Weight for the reconstruction loss term.

  • mode (str | None (default: None)) – Name of the modality.

Return type:

LossOutput

Returns:

Object containing loss components and metrics.
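
The roles of lam_data and lam_kl can be sketched as weights in a sum of the two named terms. This is a simplified illustration that assumes the weights enter linearly. The function weighted_loss is hypothetical, and the actual loss may include further terms (for example graph or classification terms) that are not shown here.

```python
def weighted_loss(reconstruction_loss, kl_divergence, lam_data=1.0, lam_kl=1.0):
    """Hypothetical helper: combine the two terms using the documented weights."""
    return lam_data * reconstruction_loss + lam_kl * kl_divergence


# Halving lam_kl halves the contribution of the KL term.
print(weighted_loss(120.0, 8.0, lam_data=1.0, lam_kl=0.5))  # 124.0
```

A smaller lam_kl loosens the pull of the prior on the latent space, while a larger lam_data emphasises reconstruction of the observed features.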

DIAGVAE.sample(tensors, n_samples=1)

Sample from the generative model for each modality.

Parameters:
  • tensors (dict[str, dict[str, Tensor]]) – Dictionary mapping modality names to their respective input tensors. Can contain one or both modalities.

  • n_samples (int (default: 1)) – Number of samples to generate per cell.

Return type:

dict[str, Tensor]

Returns:

Dictionary mapping modality names to sampled tensors of shape (n_samples, n_cells, n_features) if n_samples > 1, else (n_cells, n_features).
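
The shape rule above means that the leading sample axis is present only when more than one sample is requested. A small sketch of that rule follows; expected_sample_shape is a hypothetical helper for checking downstream code, not a function provided by the library.

```python
def expected_sample_shape(n_cells, n_features, n_samples=1):
    """Hypothetical helper: shape of one modality's entry in the returned dict."""
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")
    if n_samples > 1:
        return (n_samples, n_cells, n_features)
    return (n_cells, n_features)


print(expected_sample_shape(128, 2000))               # (128, 2000)
print(expected_sample_shape(128, 2000, n_samples=5))  # (5, 128, 2000)
```

Code that consumes the samples should therefore handle both a 2-D and a 3-D result, or always request n_samples > 1 and index the first axis.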