scvi.external.diagvi.DIAGVAE
- class scvi.external.diagvi.DIAGVAE(n_inputs, n_batches, n_labels, modality_likelihoods, normalize_lib, guidance_graph, use_gmm_prior, semi_supervised, n_mixture_components, n_latent, n_hidden, n_layers, dropout_rate)
Bases: BaseModuleClass
Variational autoencoder module for DIAGVI multi-modal integration.
Supports GMM priors, semi-supervised classification, and flexible modality-specific decoders.
- Parameters:
  - n_inputs (dict[str, int]) – Number of input features for each modality.
  - n_batches (dict[str, int]) – Number of batches for each modality.
  - n_labels (dict[str, int]) – Number of labels (classes) for each modality.
  - modality_likelihoods (dict[str, str]) – Likelihood model for each modality. One of: ‘nb’, ‘zinb’, ‘normal’, ‘log1pnormal’, ‘zig’, ‘nbmixture’.
  - normalize_lib (dict[str, bool]) – Whether to normalize counts by library size within the model, for each modality.
  - guidance_graph (Data) – Graph object encoding feature correspondences.
  - use_gmm_prior (dict[str, bool]) – Whether to use a GMM prior for each modality.
  - semi_supervised (dict[str, bool]) – Whether to use semi-supervised classification for each modality.
  - n_mixture_components (dict[str, int]) – Number of mixture components of the GMM prior for each modality. If semi_supervised is True for a modality, this value is ignored and replaced by the number of unique labels in labels_key.
  - n_latent (int) – Dimensionality of the latent space.
  - n_hidden (int) – Number of nodes per hidden layer.
  - n_layers (int) – Number of hidden layers.
  - dropout_rate (float) – Dropout rate for the encoders.
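Apart from n_latent, n_hidden, n_layers and dropout_rate, every constructor argument is a dict keyed by modality name. The sketch below shows what such per-modality dicts might look like for two hypothetical modalities ("rna" and "protein"), plus a hypothetical helper, resolve_mixture_components, that applies the documented rule replacing n_mixture_components with the number of unique labels for semi-supervised modalities. It does not build DIAGVAE itself (that also needs a guidance_graph Data object); modality names, feature counts and labels are illustrative assumptions.

```python
# Hypothetical per-modality settings; the keys are illustrative modality names.
n_inputs = {"rna": 2000, "protein": 30}
modality_likelihoods = {"rna": "nb", "protein": "normal"}
use_gmm_prior = {"rna": True, "protein": True}
semi_supervised = {"rna": True, "protein": False}
n_mixture_components = {"rna": 10, "protein": 8}

# Likelihood names listed in the parameter documentation.
VALID_LIKELIHOODS = {"nb", "zinb", "normal", "log1pnormal", "zig", "nbmixture"}


def resolve_mixture_components(n_mixture_components, semi_supervised, labels_per_modality):
    """Hypothetical helper mirroring the documented rule: a semi-supervised
    modality uses one GMM component per unique label."""
    resolved = {}
    for modality, k in n_mixture_components.items():
        if semi_supervised.get(modality, False):
            resolved[modality] = len(set(labels_per_modality[modality]))
        else:
            resolved[modality] = k
    return resolved


# Each per-modality dict should cover the same modalities.
assert set(n_inputs) == set(modality_likelihoods) == set(use_gmm_prior)
assert all(lik in VALID_LIKELIHOODS for lik in modality_likelihoods.values())

labels = {"rna": ["T", "B", "NK", "T", "B"], "protein": []}
components = resolve_mixture_components(n_mixture_components, semi_supervised, labels)
print(components)  # {'rna': 3, 'protein': 8}
```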
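Attributes table
| training |
Methods table
| encode_graph(device) | Run the graph encoder, moving the guidance graph to device on first call. |
| generative(z, library[, batch_index, y, v, mode]) | Run the generative (decoder) model for a given modality. |
| inference(x[, mode, v_all, mu_all, logvar_all, deterministic]) | Run the inference (encoder and graph) step for a given modality. |
| loss(tensors, inference_outputs, generative_outputs, lam_kl, lam_data[, mode]) | Compute the loss for a batch. |
| | Sample from the generative model for each modality. |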
Attributes
- DIAGVAE.training: bool
Methods
- DIAGVAE.encode_graph(device)
Run the graph encoder, moving the guidance graph to device on first call.
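"Moving the guidance graph to device on first call" describes a lazy-transfer pattern: the graph is copied to the target device once and the cached copy is reused on later calls. The framework-free sketch below illustrates that pattern only; FakeGraph, GraphHolder and the returned string are hypothetical stand-ins for the real graph object and graph encoder, not DIAGVAE internals.

```python
class FakeGraph:
    """Hypothetical stand-in for a graph object exposing .to(device)."""

    def __init__(self, device="cpu"):
        self.device = device
        self.transfers = 0  # counts .to() calls so caching can be observed

    def to(self, device):
        self.transfers += 1
        self.device = device
        return self


class GraphHolder:
    """Hypothetical module that transfers its graph only on the first call."""

    def __init__(self, graph):
        self.graph = graph
        self._graph_on_device = None  # filled on the first encode_graph call

    def encode_graph(self, device):
        if self._graph_on_device is None:
            self._graph_on_device = self.graph.to(device)
        # A real implementation would run the graph encoder on the cached graph here.
        return f"encoded on {self._graph_on_device.device}"


holder = GraphHolder(FakeGraph())
holder.encode_graph("cuda:0")
holder.encode_graph("cuda:0")
print(holder.graph.transfers)  # 1
```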
- DIAGVAE.generative(z, library, batch_index=None, y=None, v=None, mode=None)
Run the generative (decoder) model for a given modality.
- Parameters:
  - z (Tensor) – Latent variable tensor.
  - library (Tensor) – Library size tensor.
  - batch_index (Tensor | None, default: None) – Batch index tensor.
  - y (Tensor | None, default: None) – Cell type labels (for the semi-supervised GMM prior).
  - v (Tensor | None, default: None) – Feature embedding from the graph encoder.
- Returns:
  Dictionary containing generative model outputs and distributions.
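How a decoder can combine the latent z, the feature embedding v and the library size is sketched below in plain Python. This borrows the GLUE-style design in which each feature's score is the inner product of z with that feature's embedding, followed by a softmax over features and scaling by library size; it is an assumption for illustration, not DIAGVAE's verified decoder. decode_rates is hypothetical, and library is taken here as a raw total count.

```python
import math


def decode_rates(z, v, library):
    """Hypothetical GLUE-style decoder for one cell.

    z: latent vector (length n_latent)
    v: one embedding per feature (each of length n_latent)
    library: total count for the cell (assumed raw, not log-scale)
    """
    # Score each feature by the inner product of z with its embedding.
    logits = [sum(zi * vi for zi, vi in zip(z, v_f)) for v_f in v]
    # Softmax over features, shifted by the max for numerical stability.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    # Scale per-feature proportions by library size to get expected counts.
    return [library * e / total for e in exps]


z = [0.5, -1.0]                            # one cell, n_latent = 2
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # embeddings for 3 features
rates = decode_rates(z, v, library=100.0)
print([round(r, 2) for r in rates])
```

Because the softmax proportions sum to one, the expected counts always sum to the library size, which is what library-size normalization is meant to guarantee.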
- DIAGVAE.inference(x, mode=None, v_all=None, mu_all=None, logvar_all=None, deterministic=False)
Run the inference (encoder and graph) step for a given modality.
- Parameters:
  - x (Tensor) – Input data tensor.
  - v_all (Tensor | None, default: None) – Pre-computed graph node embeddings. If None, they are computed by the graph encoder.
  - mu_all (Tensor | None, default: None) – Pre-computed graph encoder means. Required when v_all is provided.
  - logvar_all (Tensor | None, default: None) – Pre-computed graph encoder log-variances. Required when v_all is provided.
  - deterministic (bool, default: False) – Whether to use the mean instead of a sample.
- Returns:
  Dictionary of inference outputs, including latent variables and graph embeddings.
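The deterministic flag and the rule for pre-computed graph arguments can be illustrated with a small stdlib sketch. It assumes the usual Gaussian reparameterization used by VAE encoders (z = mu + sigma * eps with sigma = exp(logvar / 2)), which this page does not spell out for DIAGVAE; check_precomputed and sample_latent are hypothetical helpers, not methods of the module.

```python
import math
import random


def check_precomputed(v_all, mu_all, logvar_all):
    """Hypothetical check of the documented rule: mu_all and logvar_all
    must accompany v_all when pre-computed embeddings are passed."""
    if v_all is not None and (mu_all is None or logvar_all is None):
        raise ValueError("mu_all and logvar_all are required when v_all is provided")


def sample_latent(mu, logvar, deterministic=False, rng=random):
    """Return mu when deterministic, else a reparameterized draw
    z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, 1)."""
    if deterministic:
        return list(mu)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


mu, logvar = [0.2, -0.4], [0.0, math.log(0.25)]
print(sample_latent(mu, logvar, deterministic=True))   # [0.2, -0.4]
print(sample_latent(mu, logvar, rng=random.Random(0)))  # a random draw around mu
```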
- DIAGVAE.loss(tensors, inference_outputs, generative_outputs, lam_kl, lam_data, mode=None)
Compute the loss for a batch.
- Returns:
  Object containing loss components and metrics.
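The loss signature exposes two weights, lam_kl and lam_data. A minimal sketch of one plausible way such weights combine a data (reconstruction) term with a KL term follows, using the closed-form KL between a diagonal Gaussian and a standard normal. This is an assumption: the actual method may add further terms (this page does not list them), and with a GMM prior the KL cannot use this standard-normal closed form. Both helpers are hypothetical.

```python
import math


def gaussian_kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)) for one cell."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def weighted_loss(reconstruction_nll, kl, lam_data, lam_kl):
    """Hypothetical weighting: lam_data scales the data term, lam_kl the KL term."""
    return lam_data * reconstruction_nll + lam_kl * kl


print(gaussian_kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0
kl = gaussian_kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])
print(kl)  # 0.5
print(weighted_loss(reconstruction_nll=10.0, kl=kl, lam_data=1.0, lam_kl=0.5))  # 10.25
```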