scvi.external.JointEmbeddingVAE
- class scvi.external.JointEmbeddingVAE(n_input, *args, joint_embedding_weight=100.0, lambda_off_diag=0.01, min_library_size=10.0, reconstruction_weight=1.0, variance_weight=0.0, use_joint_embedding=True, **kwargs)
Bases: VAE
VAE with joint embedding loss using binomial thinning and CCO [Svensson, 2026].
This module extends the standard VAE with a cross-correlation objective (CCO) loss that encourages the embedding of a binomially thinned view of each cell to match the embedding of its original counts. This promotes robustness to count dropout and sampling noise.
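The exact form of the CCO loss is not spelled out here. A minimal NumPy sketch, assuming it follows the Barlow Twins pattern (standardize each latent dimension across the batch, build the cross-correlation matrix between the two views, push its diagonal toward 1 and penalize off-diagonal entries weighted by `lambda_off_diag`). The function name `cco_loss` and the `eps` constant are hypothetical; the module itself works on torch tensors during training.

```python
import numpy as np


def cco_loss(z_orig, z_thin, lambda_off_diag=0.01, eps=1e-5):
    """Hypothetical Barlow Twins-style cross-correlation loss between two views.

    z_orig: (n_cells, d) embeddings of the original counts.
    z_thin: (n_cells, d) embeddings of the thinned counts.
    """
    z1 = np.asarray(z_orig, dtype=float)
    z2 = np.asarray(z_thin, dtype=float)
    n = z1.shape[0]
    # Standardize every latent dimension across the batch of cells.
    z1n = (z1 - z1.mean(axis=0)) / (z1.std(axis=0) + eps)
    z2n = (z2 - z2.mean(axis=0)) / (z2.std(axis=0) + eps)
    # d x d cross-correlation matrix between the two views.
    c = z1n.T @ z2n / n
    diag = np.diag(c)
    # Invariance term: matching dimensions should be perfectly correlated.
    on_diag = np.sum((1.0 - diag) ** 2)
    # Redundancy term: different dimensions should be decorrelated.
    off_diag = np.sum(c ** 2) - np.sum(diag ** 2)
    return on_diag + lambda_off_diag * off_diag
```

Identical views give a loss near zero, while unrelated views are penalized heavily through the diagonal term, which is the behavior the paragraph above describes.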
Thinning probabilities are sampled per cell at each training step so that the target library sizes are log-uniformly distributed between min_library_size and the cell's observed library size. This is intended to mimic the library size variation seen in real single-cell data.
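The thinning step can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the module's implementation: `thin_counts` is a hypothetical name, counts are assumed to be a dense integer matrix (cells x genes), and cells whose observed library size is already at or below `min_library_size` are assumed to be left unthinned (probability 1), which the text above does not specify.

```python
import numpy as np


def thin_counts(x, min_library_size=10.0, rng=None):
    """Hypothetical per-cell binomial thinning to a log-uniform target library size."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x, dtype=np.int64)
    lib = x.sum(axis=1).astype(float)
    # Guard against empty cells so the logarithm stays finite.
    lib_safe = np.maximum(lib, 1.0)
    log_hi = np.log(lib_safe)
    # Assumption: if the observed library is below the minimum, do not thin.
    log_lo = np.minimum(np.log(min_library_size), log_hi)
    # Uniform in log space gives a log-uniform target library size per cell.
    u = rng.random(x.shape[0])
    target = np.exp(log_lo + u * (log_hi - log_lo))
    # Per-cell keep probability that yields the target library in expectation.
    p = np.clip(target / lib_safe, 0.0, 1.0)
    # Each count is independently kept with probability p (binomial thinning).
    thinned = rng.binomial(x, p[:, None])
    return thinned, p
```

Because each count survives independently with probability p, the expected thinned library size equals p times the observed library size, i.e. the sampled target.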
- Parameters:
  - n_input (int) – Number of input features.
  - joint_embedding_weight (float (default: 100.0)) – Weight for the CCO loss.
  - lambda_off_diag (float (default: 0.01)) – Off-diagonal penalty in the CCO loss.
  - min_library_size (float (default: 10.0)) – Minimum target library size for thinning. Thinned library sizes are sampled log-uniformly between this value and the observed library size.
  - reconstruction_weight (float (default: 1.0)) – Weight for the reconstruction loss. Set to 0.0 for purely self-supervised training with only the CCO loss.
  - variance_weight (float (default: 0.0)) – Weight for the VICReg-style variance regularization loss. Set to a positive value (e.g., 1.0) to prevent dimensional collapse in self-supervised training.
  - use_joint_embedding (bool (default: True)) – Whether to use the joint embedding loss. Set to False to train as a standard SCVI model.
  - **kwargs – Additional keyword arguments passed to VAE.
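The variance_weight term is described as VICReg-style. A minimal NumPy sketch of the standard VICReg variance term (a hinge on the per-dimension standard deviation), assuming a target standard deviation `gamma` of 1.0; the function name `variance_loss` and both constants are hypothetical and may differ from what the module uses.

```python
import numpy as np


def variance_loss(z, gamma=1.0, eps=1e-4):
    """Hypothetical VICReg-style variance term for a (n_cells, d) embedding."""
    z = np.asarray(z, dtype=float)
    # Per-dimension standard deviation across the batch, stabilized by eps.
    std = np.sqrt(z.var(axis=0) + eps)
    # Penalize only dimensions whose spread falls below gamma.
    return float(np.mean(np.maximum(0.0, gamma - std)))
```

A collapsed embedding (every cell mapped to the same point) scores close to `gamma`, while an embedding whose dimensions all vary by at least `gamma` scores zero, so the term counteracts the collapse that a pure CCO objective can otherwise drift toward.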
Attributes table
| training | |
Methods table
| loss(tensors, inference_outputs, ...) | Compute the loss including optional joint embedding CCO loss. |
Attributes
- JointEmbeddingVAE.training: bool
Methods
- JointEmbeddingVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0, joint_embedding_weight=None, reconstruction_weight=None)
Compute the loss including optional joint embedding CCO loss.
- Parameters:
  - inference_outputs (dict[str, Tensor | Distribution | None]) – Dictionary of inference outputs.
  - generative_outputs (dict[str, Distribution | None]) – Dictionary of generative outputs.
  - kl_weight (float (default: 1.0)) – Weight for the KL divergence term.
  - joint_embedding_weight (float | None (default: None)) – Optional override for the joint embedding weight.
  - reconstruction_weight (float | None (default: None)) – Optional override for the reconstruction weight. Set to 0.0 for purely self-supervised training.
- Return type:
  LossOutput
- Returns:
  Loss output with the total loss, reconstruction loss, KL terms, and the CCO loss in extra_metrics.
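How the terms are combined is not stated above. A minimal sketch, assuming the total is a plain weighted sum of scalar terms and that a `None` override falls back to the value given at construction. The function `combine_losses`, its argument names, and the choice to gate only the CCO term on `use_joint_embedding` are all hypothetical; the real method returns a LossOutput rather than a dict.

```python
def combine_losses(
    recon,
    kl,
    cco,
    var_term,
    *,
    kl_weight=1.0,
    default_joint_weight=100.0,
    default_recon_weight=1.0,
    variance_weight=0.0,
    use_joint_embedding=True,
    joint_embedding_weight=None,
    reconstruction_weight=None,
):
    """Hypothetical weighted sum of already-reduced scalar loss terms."""
    # Per-call overrides take precedence over constructor defaults.
    jw = default_joint_weight if joint_embedding_weight is None else joint_embedding_weight
    rw = default_recon_weight if reconstruction_weight is None else reconstruction_weight
    total = rw * recon + kl_weight * kl + variance_weight * var_term
    if use_joint_embedding:
        total += jw * cco
    # Mirrors the idea of reporting the CCO loss as an extra metric.
    return {"loss": total, "extra_metrics": {"cco_loss": cco}}
```

For example, with recon=2.0, kl=1.0 and cco=0.5 the defaults give 2.0 + 1.0 + 100.0 * 0.5 = 53.0, and passing reconstruction_weight=0.0 drops the reconstruction term for self-supervised training.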