scvi.external.sysvi.SysVAE#

class scvi.external.sysvi.SysVAE(n_input, n_batch, n_continuous_cov=0, n_cats_per_cov=None, embed_categorical_covariates=False, prior='vamp', n_prior_components=5, trainable_priors=True, pseudoinput_data=None, n_latent=15, n_hidden=256, n_layers=2, dropout_rate=0.1, out_var_mode='feature', encoder_decoder_kwargs=None, embedding_kwargs=None)[source]#

Bases: BaseModuleClass, EmbeddingModuleMixin

Conditional VAE (CVAE) with optional VampPrior and latent cycle consistency loss.

Described in Hrovatin et al. (2023).

Parameters:
  • n_input (int) – Number of input features.

  • n_batch (int) – Number of batches.

  • n_continuous_cov (int (default: 0)) – Number of continuous covariates.

  • n_cats_per_cov (list[int] | None (default: None)) – A list of integers containing the number of categories for each categorical covariate.

  • embed_categorical_covariates (bool (default: False)) – If True, embeds categorical covariates and batches into continuous-valued vectors instead of one-hot encoding them.

  • prior (Literal['standard_normal', 'vamp'] (default: 'vamp')) – Which prior distribution to use: 'standard_normal' for a standard normal distribution, or 'vamp' for a VampPrior.

  • n_prior_components (int (default: 5)) – Number of prior components for VampPrior.

  • trainable_priors (bool (default: True)) – Whether the prior components of the VampPrior are trainable.

  • pseudoinput_data (dict[str, Tensor] | None (default: None)) – Initialisation data for the VampPrior. Should match the structure of the input tensors.

  • n_latent (int (default: 15)) – Number of latent space dimensions.

  • n_hidden (int (default: 256)) – Number of hidden nodes per layer for encoder and decoder.

  • n_layers (int (default: 2)) – Number of hidden layers for encoder and decoder.

  • dropout_rate (float (default: 0.1)) – Dropout rate for encoder and decoder.

  • out_var_mode (Literal['sample_feature', 'feature'] (default: 'feature')) – How variance is predicted in the decoder, see VarEncoder. Either 'sample_feature' (learn variance per sample and feature) or 'feature' (learn variance per feature, constant across samples).

  • encoder_decoder_kwargs (dict | None (default: None)) – Additional kwargs passed to encoder and decoder. Both encoder and decoder use EncoderDecoder.

  • embedding_kwargs (dict | None (default: None)) – Keyword arguments passed into Embedding if embed_categorical_covariates is set to True.
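To make the 'vamp' option of prior more concrete: a VampPrior is a mixture of Gaussians, and n_prior_components sets how many components the mixture has. The sketch below evaluates the log-density of such a mixture with diagonal covariances in plain Python. The function names, the uniform default weights and the hard-coded component parameters are assumptions made for illustration; this is not the scvi-tools implementation, where the component parameters would be derived from pseudoinputs.

```python
import math


def diag_normal_logpdf(z, mean, var):
    """Log-density of a diagonal Gaussian at point z."""
    return sum(
        -0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
        for x, m, v in zip(z, mean, var)
    )


def mixture_log_prob(z, means, variances, weights=None):
    """Log-density of a Gaussian mixture, computed with a stable log-sum-exp."""
    k = len(means)
    if weights is None:
        # Assumption for this sketch: equally weighted components.
        weights = [1.0 / k] * k
    terms = [
        math.log(w) + diag_normal_logpdf(z, m, v)
        for w, m, v in zip(weights, means, variances)
    ]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


# Hypothetical two-component prior in a 2-dimensional latent space.
means = [[0.0, 0.0], [3.0, 3.0]]
variances = [[1.0, 1.0], [1.0, 1.0]]
z = [3.0, 3.0]

standard_normal_lp = diag_normal_logpdf(z, [0.0, 0.0], [1.0, 1.0])
mixture_lp = mixture_log_prob(z, means, variances)
```

In this toy setting the point z sits on the second component, so the mixture assigns it a higher log-density than a single standard normal would. This illustrates why a multi-component prior can accommodate a latent space with several distinct modes.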

Attributes table#

Methods table#

forward(tensors[, ...])

Forward pass through the network.

generative(z, batch_index[, cont_covs, ...])

Generative: latent representation & covariates -> expression.

inference(x, batch_index[, cont_covs, ...])

Inference: expression & covariates -> latent representation.

latent_cycle_consistency(qz, qz_cycle)

MSE loss between standardised inputs.

loss(tensors, inference_outputs, ...[, ...])

Compute loss of forward pass.

random_select_batch(batch)

Randomly selects a new batch different from the real one for each cell.

sample(*args, **kwargs)

Generate expression samples from posterior generative distribution.

Attributes#

SysVAE.training: bool#

Methods#

SysVAE.forward(tensors, get_inference_input_kwargs=None, get_generative_input_kwargs=None, inference_kwargs=None, generative_kwargs=None, loss_kwargs=None, compute_loss=True)[source]#

Forward pass through the network.

Parameters:
  • tensors (dict[str, Tensor]) – Tensors to pass through the network.

  • get_inference_input_kwargs (dict | None (default: None)) – Keyword args for _get_inference_input()

  • get_generative_input_kwargs (dict | None (default: None)) – Keyword args for _get_generative_input()

  • inference_kwargs (dict | None (default: None)) – Keyword args for inference()

  • generative_kwargs (dict | None (default: None)) – Keyword args for generative()

  • loss_kwargs (dict | None (default: None)) – Keyword args for loss()

  • compute_loss (bool (default: True)) – Whether to compute the loss on the forward pass. If True, a LossOutput is returned as an additional third value.

Return type:

tuple[dict[str, Tensor], dict[str, Tensor]] | tuple[dict[str, Tensor], dict[str, Tensor], LossOutput]
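Because the length of the returned tuple depends on compute_loss, calling code may want to unpack it defensively. The following sketch uses a hypothetical stand-in function, forward_like, that mimics only the return convention described above; the dictionary keys and the placeholder loss object are assumptions and do not reflect the real outputs of SysVAE.forward.

```python
def forward_like(tensors, compute_loss=True):
    """Hypothetical stand-in that mimics only the return convention of forward()."""
    inference_outputs = {"latent": tensors["x"]}  # placeholder key
    generative_outputs = {"reconstruction": tensors["x"]}  # placeholder key
    if compute_loss:
        loss_output = {"loss": 0.0}  # stands in for a LossOutput object
        return inference_outputs, generative_outputs, loss_output
    return inference_outputs, generative_outputs


def unpack_forward(outputs):
    """Split a 2- or 3-tuple into (inference, generative, loss_or_None)."""
    inference_outputs, generative_outputs = outputs[:2]
    loss_output = outputs[2] if len(outputs) == 3 else None
    return inference_outputs, generative_outputs, loss_output


batch = {"x": [1.0, 2.0]}
with_loss = unpack_forward(forward_like(batch, compute_loss=True))
without_loss = unpack_forward(forward_like(batch, compute_loss=False))
```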

SysVAE.generative(z, batch_index, cont_covs=None, cat_covs=None, cycle_batch=None, compute_original=True, compute_cycle=True, y=None, transform_batch=None)[source]#

Generative: latent representation & covariates -> expression.

Return type:

dict[str, Tensor]

SysVAE.inference(x, batch_index, cont_covs=None, cat_covs=None, n_samples=1)[source]#

Inference: expression & covariates -> latent representation.

Return type:

dict[str, Tensor | Distribution | None]

static SysVAE.latent_cycle_consistency(qz, qz_cycle)[source]#

MSE loss between standardised inputs.

The MSE loss should be computed on standardized latent representations; otherwise the model can learn to cheat the MSE loss by shrinking the latent representations towards smaller values. The standardizer is fitted on the concatenation of both inputs.

Parameters:
  • qz (Tensor) – Posterior distribution from the inference pass.

  • qz_cycle (Tensor) – Posterior distribution from the cycle inference pass.

Return type:

Tensor
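The reasoning behind the standardization can be checked with a small pure-Python sketch. It fits a per-dimension mean and standard deviation on the concatenation of both inputs, standardizes each input, and computes a per-cell mean squared error. The function names, the list-of-lists data layout and the per-cell averaging are assumptions for illustration, not the library's tensor-based implementation.

```python
from statistics import mean, stdev


def raw_mse(a_rows, b_rows):
    """Per-cell mean squared error without any standardization."""
    return [
        sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
        for a, b in zip(a_rows, b_rows)
    ]


def standardized_mse(qz, qz_cycle):
    """Per-cell MSE after standardizing with statistics of both inputs pooled."""
    pooled = qz + qz_cycle  # concatenate the cells of both inputs
    n_dims = len(qz[0])
    mu = [mean([row[d] for row in pooled]) for d in range(n_dims)]
    sd = [stdev([row[d] for row in pooled]) for d in range(n_dims)]

    def standardize(row):
        return [(v - m) / s for v, m, s in zip(row, mu, sd)]

    return raw_mse([standardize(r) for r in qz], [standardize(r) for r in qz_cycle])


def scale(rows, factor):
    """Multiply every latent value by a constant factor."""
    return [[v * factor for v in row] for row in rows]


# Hypothetical latent means for three cells in a 2-dimensional latent space.
qz = [[1.0, 2.0], [3.0, 0.5], [-1.0, 1.5]]
qz_cycle = [[1.2, 1.8], [2.5, 0.7], [-0.6, 1.0]]

original = standardized_mse(qz, qz_cycle)
shrunk = standardized_mse(scale(qz, 0.01), scale(qz_cycle, 0.01))
raw_original = raw_mse(qz, qz_cycle)
raw_shrunk = raw_mse(scale(qz, 0.01), scale(qz_cycle, 0.01))
```

Shrinking both representations by a constant factor lowers the raw MSE but leaves the standardized MSE unchanged up to floating-point error, which is the behaviour the standardization is meant to enforce.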

SysVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0, reconstruction_weight=1.0, z_distance_cycle_weight=2.0, compute_cycle=True)[source]#

Compute loss of forward pass.

Parameters:
  • tensors (dict[str, Tensor]) – Input tensors.

  • inference_outputs (dict[str, Tensor]) – Outputs of normal and cycle inference pass.

  • generative_outputs (dict[str, Tensor]) – Outputs of the normal generative pass.

  • kl_weight (float (default: 1.0)) – Weight for KL loss.

  • reconstruction_weight (float (default: 1.0)) – Weight for reconstruction loss.

  • z_distance_cycle_weight (float (default: 2.0)) – Weight for cycle loss.

Return type:

LossOutput

Returns:

Loss components. The cycle loss is added to the extra metrics as 'cycle_loss'.
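One way to picture how the three weights interact is a weighted sum of per-cell loss terms that is then averaged over cells. The sketch below makes exactly that assumption; the function name combine_losses, the plain-list inputs and the averaging step are illustrative choices and may differ from what SysVAE.loss does internally.

```python
def combine_losses(
    reconstruction,
    kl,
    cycle,
    kl_weight=1.0,
    reconstruction_weight=1.0,
    z_distance_cycle_weight=2.0,
    compute_cycle=True,
):
    """Weighted sum of per-cell loss terms, averaged over cells (assumed form)."""
    per_cell = []
    for rec, kl_term, cyc in zip(reconstruction, kl, cycle):
        total = reconstruction_weight * rec + kl_weight * kl_term
        if compute_cycle:
            total += z_distance_cycle_weight * cyc
        per_cell.append(total)
    loss = sum(per_cell) / len(per_cell)
    # Mirror the documented behaviour of reporting the cycle loss separately.
    extra_metrics = {"cycle_loss": sum(cycle) / len(cycle)} if compute_cycle else {}
    return loss, extra_metrics


# Hypothetical per-cell loss terms for two cells.
loss, extra = combine_losses(
    reconstruction=[1.0, 3.0], kl=[0.5, 0.5], cycle=[0.25, 0.75]
)
loss_no_cycle, extra_no_cycle = combine_losses(
    reconstruction=[1.0, 3.0], kl=[0.5, 0.5], cycle=[0.25, 0.75], compute_cycle=False
)
```

With the default z_distance_cycle_weight of 2.0, the cycle term counts twice as much per unit as the reconstruction and KL terms in this sketch.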

SysVAE.random_select_batch(batch)[source]#

Randomly selects a new batch different from the real one for each cell.

Parameters:

batch (torch.Tensor) – Tensor containing the real batch index for each cell.

Return type:

Tensor

Returns:

torch.Tensor – Tensor with newly assigned batch indices for each cell.
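A simple way to draw a batch that is guaranteed to differ from the real one is to sample from the n_batch - 1 remaining indices and shift the draw past the real index. The sketch below does this with the standard library. The function name random_other_batch, the explicit n_batch argument and the list-based input are assumptions for illustration, and the library method may use a different sampling scheme.

```python
import random


def random_other_batch(batch, n_batch, rng=None):
    """For each cell, pick a batch index uniformly among all batches except its own."""
    if n_batch < 2:
        raise ValueError("At least two batches are needed to pick a different one.")
    if rng is None:
        rng = random.Random()
    new_batch = []
    for real in batch:
        draw = rng.randrange(n_batch - 1)  # one of the n_batch - 1 other slots
        # Skip over the real index so the result can never equal it.
        new_batch.append(draw + 1 if draw >= real else draw)
    return new_batch


# Hypothetical batch assignments for six cells across three batches.
real_batch = [0, 1, 2, 0, 1, 2]
cycle_batch = random_other_batch(real_batch, n_batch=3, rng=random.Random(0))
```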

SysVAE.sample(*args, **kwargs)[source]#

Generate expression samples from posterior generative distribution.

Not implemented as the use of decoded expression is not recommended for SysVI.

Raises:

NotImplementedError