scvi.external.JointEmbeddingVAE

class scvi.external.JointEmbeddingVAE(n_input, *args, joint_embedding_weight=100.0, lambda_off_diag=0.01, min_library_size=10.0, reconstruction_weight=1.0, variance_weight=0.0, use_joint_embedding=True, **kwargs)

Bases: VAE

VAE with joint embedding loss using binomial thinning and CCO [Svensson, 2026].

This module extends the standard VAE with a cross-correlation objective (CCO) loss. The CCO loss encourages the embedding of a binomially thinned view of each cell to match the embedding of its original counts. Training on thinned views is intended to make the embedding robust to count dropout and sampling noise.
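The exact CCO formula is not given here. Because the loss has an off-diagonal penalty (lambda_off_diag), one plausible reading is a Barlow Twins-style objective. Under that reading, you standardize each latent dimension across the batch and form the cross-correlation matrix C between the two views. The loss is then sum_i (1 - C_ii)^2 + lambda_off_diag * sum_{i != j} C_ij^2. The NumPy sketch below assumes that form. cco_loss is a hypothetical helper, not part of scvi-tools, and the module's actual normalization and scaling may differ.

```python
import numpy as np


def cco_loss(z_a, z_b, lambda_off_diag=0.01, eps=1e-8):
    """Hypothetical Barlow Twins-style cross-correlation loss (assumed form).

    z_a, z_b: (batch, latent_dim) embeddings of the original and thinned
    views; row i of each array belongs to the same cell.
    """
    # Standardize every latent dimension across the batch.
    z_a = (z_a - z_a.mean(axis=0)) / (z_a.std(axis=0) + eps)
    z_b = (z_b - z_b.mean(axis=0)) / (z_b.std(axis=0) + eps)
    # Empirical cross-correlation matrix, shape (latent_dim, latent_dim).
    c = z_a.T @ z_b / z_a.shape[0]
    diag = np.diag(c)
    # Pull each dimension's correlation between the two views toward 1.
    on_diag = np.sum((1.0 - diag) ** 2)
    # Decorrelate different dimensions, scaled by lambda_off_diag.
    off_diag = np.sum(c ** 2) - np.sum(diag ** 2)
    return on_diag + lambda_off_diag * off_diag


rng = np.random.default_rng(0)
z = rng.normal(size=(256, 10))
noisy = z + 0.1 * rng.normal(size=z.shape)  # stand-in for a thinned view
unrelated = rng.normal(size=z.shape)
print(cco_loss(z, noisy), cco_loss(z, unrelated))
```

An embedding that barely changes under perturbation scores close to zero. Unrelated embeddings score roughly one unit per latent dimension from the diagonal term alone.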

Thinning probabilities are sampled per cell so that the target library sizes are log-uniformly distributed between min_library_size and the cell's observed library size. This is intended to mimic the wide range of library sizes seen in single-cell data.
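Here is a minimal NumPy sketch of per-cell binomial thinning. It assumes each cell's target library size t satisfies log t ~ Uniform(log min_library_size, log L), where L is the observed library size. Every count is then kept with probability p = t / L. thin_counts is a hypothetical helper, and the module presumably does this on PyTorch tensors during training. Leaving cells whose library is already below min_library_size unthinned is also an assumption.

```python
import numpy as np


def thin_counts(x, min_library_size=10.0, rng=None):
    """Hypothetical helper: binomially thin each cell (row) of a count matrix.

    Assumes log-uniform target library sizes in [min_library_size, observed].
    Returns the thinned counts and the per-cell thinning probabilities.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x, dtype=np.int64)
    library = x.sum(axis=1).astype(float)
    # Assumption: cells already below min_library_size are not thinned.
    upper = np.maximum(library, min_library_size)
    # Sampling log(target) uniformly makes the targets log-uniform.
    target = np.exp(rng.uniform(np.log(min_library_size), np.log(upper)))
    # Per-cell keep probability; empty or tiny cells end up with p = 1.
    p = np.clip(target / np.maximum(library, 1.0), 0.0, 1.0)
    # Each count n is replaced by a Binomial(n, p) draw using its cell's p.
    thinned = rng.binomial(x, p[:, None])
    return thinned, p


rng = np.random.default_rng(0)
counts = rng.poisson(5.0, size=(4, 200))  # 4 toy cells, about 1000 counts each
thinned, p = thin_counts(counts, min_library_size=10.0, rng=rng)
print(counts.sum(axis=1), thinned.sum(axis=1))
```

Binomial thinning suits count data because thinning a Poisson(mu) count with keep probability p yields a Poisson(p * mu) count. A thinned cell therefore looks like the same cell sequenced at lower depth.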

Parameters:
  • n_input (int) – Number of input features.

  • joint_embedding_weight (float (default: 100.0)) – Weight for the CCO loss. Default is 100.0.

  • lambda_off_diag (float (default: 0.01)) – Off-diagonal penalty in CCO loss. Default is 0.01.

  • min_library_size (float (default: 10.0)) – Minimum target library size for thinning. Default is 10.0. Thinned library sizes are sampled log-uniformly between this value and the observed library size.

  • reconstruction_weight (float (default: 1.0)) – Weight for reconstruction loss. Default is 1.0. Set to 0.0 for pure self-supervised training with only CCO loss.

  • variance_weight (float (default: 0.0)) – Weight for the variance regularization loss (VICReg-style). Default is 0.0. Set to a positive value (e.g., 1.0) to prevent dimensional collapse in self-supervised training.

  • use_joint_embedding (bool (default: True)) – Whether to use joint embedding loss. Default is True. Set to False to train as standard SCVI.

  • **kwargs – Additional keyword arguments passed to VAE.
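The variance_weight parameter refers to a VICReg-style term. The sketch below uses the standard VICReg form and assumes the module follows it. For each latent dimension, S = sqrt(Var(z) + eps) is computed across the batch and penalized as max(0, gamma - S), averaged over dimensions. gamma = 1 and eps = 1e-4 are the VICReg defaults. variance_loss and its gamma/eps arguments are hypothetical and are not documented parameters of JointEmbeddingVAE.

```python
import numpy as np


def variance_loss(z, gamma=1.0, eps=1e-4):
    """Hypothetical VICReg-style variance term for a (batch, latent_dim) array.

    Penalizes every latent dimension whose batch standard deviation falls
    below gamma, discouraging all cells from collapsing onto one point.
    """
    std = np.sqrt(z.var(axis=0) + eps)
    return float(np.mean(np.maximum(0.0, gamma - std)))


rng = np.random.default_rng(0)
collapsed = np.ones((128, 10))  # every cell embedded at the same point
spread = 3.0 * rng.normal(size=(128, 10))
print(variance_loss(collapsed), variance_loss(spread))
```

A fully collapsed embedding costs close to gamma, while an embedding whose dimensions already vary by more than gamma costs nothing. This is why the term is useful when reconstruction_weight is 0.0 and only the CCO loss shapes the latent space.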

See also

VAE

Attributes table

Methods table

loss(tensors, inference_outputs, ...[, ...])

Compute the loss including optional joint embedding CCO loss.

Attributes

JointEmbeddingVAE.training: bool

Methods

JointEmbeddingVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0, joint_embedding_weight=None, reconstruction_weight=None)

Compute the loss including optional joint embedding CCO loss.

Parameters:
  • tensors (dict[str, Tensor]) – Dictionary of input tensors.

  • inference_outputs (dict[str, Tensor | Distribution | None]) – Dictionary of inference outputs.

  • generative_outputs (dict[str, Distribution | None]) – Dictionary of generative outputs.

  • kl_weight (float (default: 1.0)) – Weight for KL divergence term.

  • joint_embedding_weight (float | None (default: None)) – Optional override for joint embedding weight.

  • reconstruction_weight (float | None (default: None)) – Optional override for reconstruction weight. Set to 0.0 for pure self-supervised training.
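This reference does not document how the terms are combined. The sketch below shows one plausible reading of the weights and overrides. A None override falls back to the value set at construction, and reconstruction_weight scales the reconstruction term. The CCO and variance terms are added only when use_joint_embedding is True. It operates on precomputed scalar terms. LossWeights and total_loss are hypothetical names. Keeping the KL term when reconstruction_weight is 0.0 is an assumption, as is gating the variance term on use_joint_embedding.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LossWeights:
    """Hypothetical container mirroring the constructor defaults."""
    joint_embedding_weight: float = 100.0
    reconstruction_weight: float = 1.0
    variance_weight: float = 0.0
    use_joint_embedding: bool = True


def total_loss(recon, kl, cco, var, weights, kl_weight=1.0,
               joint_embedding_weight: Optional[float] = None,
               reconstruction_weight: Optional[float] = None):
    """Combine precomputed scalar loss terms under an assumed weighting."""
    # A None override falls back to the value fixed at construction time.
    je_w = (weights.joint_embedding_weight
            if joint_embedding_weight is None else joint_embedding_weight)
    rec_w = (weights.reconstruction_weight
             if reconstruction_weight is None else reconstruction_weight)
    # Assumption: the KL term is kept even when rec_w is 0.0.
    loss = rec_w * recon + kl_weight * kl
    # Assumption: CCO and variance terms are both gated by use_joint_embedding.
    if weights.use_joint_embedding:
        loss += je_w * cco + weights.variance_weight * var
    return loss


w = LossWeights()
print(total_loss(2.0, 0.5, 0.25, 0.0, w))  # → 27.5
print(total_loss(2.0, 0.5, 0.25, 0.0, w, reconstruction_weight=0.0))  # → 25.5
print(total_loss(2.0, 0.5, 0.25, 0.0, LossWeights(use_joint_embedding=False)))  # → 2.5
```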

Return type:

LossOutput

Returns:

LossOutput containing the total loss, the reconstruction loss, and the KL terms, with the CCO loss reported in extra_metrics.