scvi.external.sysvi.SysVAE#
- class scvi.external.sysvi.SysVAE(n_input, n_batch, n_continuous_cov=0, n_cats_per_cov=None, embed_categorical_covariates=False, prior='vamp', n_prior_components=5, trainable_priors=True, pseudoinput_data=None, n_latent=15, n_hidden=256, n_layers=2, dropout_rate=0.1, out_var_mode='feature', encoder_decoder_kwargs=None, embedding_kwargs=None)[source]#
Bases: BaseModuleClass, EmbeddingModuleMixin
CVAE with optional VampPrior and latent cycle consistency loss.
Described in Hrovatin et al. (2023).
- Parameters:
  - n_input (int) – Number of input features.
  - n_batch (int) – Number of batches.
  - n_continuous_cov (int, default: 0) – Number of continuous covariates.
  - n_cats_per_cov (list[int] | None, default: None) – A list of integers containing the number of categories for each categorical covariate.
  - embed_categorical_covariates (bool, default: False) – If True, embeds categorical covariates and batches into continuously-valued vectors instead of using one-hot encoding.
  - prior (Literal['standard_normal', 'vamp'], default: 'vamp') – Which prior distribution to use. 'standard_normal': standard normal distribution; 'vamp': VampPrior.
  - n_prior_components (int, default: 5) – Number of prior components for VampPrior.
  - trainable_priors (bool, default: True) – Whether the prior components of VampPrior are trainable.
  - pseudoinput_data (dict[str, Tensor] | None, default: None) – Initialisation data for VampPrior. Should match the structure of the input tensors.
  - n_latent (int, default: 15) – Number of latent space dimensions.
  - n_hidden (int, default: 256) – Number of hidden nodes per layer in the encoder and decoder.
  - n_layers (int, default: 2) – Number of hidden layers in the encoder and decoder.
  - dropout_rate (float, default: 0.1) – Dropout rate for the encoder and decoder.
  - out_var_mode (Literal['sample_feature', 'feature'], default: 'feature') – How variance is predicted in the decoder; see VarEncoder. 'sample_feature': learn variance per sample and feature; 'feature': learn variance per feature, constant across samples.
  - encoder_decoder_kwargs (dict | None, default: None) – Additional kwargs passed to the encoder and decoder. Both encoder and decoder use EncoderDecoder.
  - embedding_kwargs (dict | None, default: None) – Keyword arguments passed to Embedding if embed_categorical_covariates is set to True.
Attributes table#
Methods table#
- forward – Forward pass through the network.
- generative – Generative: latent representation & covariates -> expression.
- inference – Inference: expression & covariates -> latent representation.
- latent_cycle_consistency – MSE loss between standardized inputs.
- loss – Compute loss of forward pass.
- random_select_batch – Randomly selects a new batch different from the real one for each cell.
- Generate expression samples from posterior generative distribution.
Attributes#
- SysVAE.training: bool#
Methods#
- SysVAE.forward(tensors, get_inference_input_kwargs=None, get_generative_input_kwargs=None, inference_kwargs=None, generative_kwargs=None, loss_kwargs=None, compute_loss=True)[source]#
Forward pass through the network.
- Parameters:
  - get_inference_input_kwargs (dict | None, default: None) – Keyword args for _get_inference_input().
  - get_generative_input_kwargs (dict | None, default: None) – Keyword args for _get_generative_input().
  - inference_kwargs (dict | None, default: None) – Keyword args for inference().
  - generative_kwargs (dict | None, default: None) – Keyword args for generative().
  - loss_kwargs (dict | None, default: None) – Keyword args for loss().
  - compute_loss (bool, default: True) – Whether to compute the loss on the forward pass. This adds another return value.
- Return type:
  tuple[dict[str, Tensor], dict[str, Tensor]] | tuple[dict[str, Tensor], dict[str, Tensor], LossOutput]
- SysVAE.generative(z, batch_index, cont_covs=None, cat_covs=None, cycle_batch=None, compute_original=True, compute_cycle=True, y=None, transform_batch=None)[source]#
Generative: latent representation & covariates -> expression.
- SysVAE.inference(x, batch_index, cont_covs=None, cat_covs=None, n_samples=1)[source]#
Inference: expression & covariates -> latent representation.
- static SysVAE.latent_cycle_consistency(qz, qz_cycle)[source]#
MSE loss between standardized inputs.
The MSE loss is computed on standardized latent representations; otherwise the model could cheat the loss by shrinking the latent representations towards zero. The standardizer is fitted on the concatenation of both inputs.
- Parameters:
  - qz (Tensor) – Posterior distribution from the inference pass.
  - qz_cycle (Tensor) – Posterior distribution from the cycle inference pass.
- Return type:
  Tensor
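The joint standardization described above can be sketched in NumPy (variable names are illustrative; the module operates on the posterior distributions rather than raw arrays):

```python
import numpy as np

def latent_cycle_mse(z, z_cycle, eps=1e-8):
    """MSE between two latent batches, standardized jointly.

    Fitting the standardizer on the concatenation of both inputs
    prevents the model from shrinking the loss simply by scaling
    all latent values towards zero.
    """
    both = np.concatenate([z, z_cycle], axis=0)
    mean, std = both.mean(axis=0), both.std(axis=0) + eps
    z_s = (z - mean) / std
    z_cycle_s = (z_cycle - mean) / std
    return np.mean((z_s - z_cycle_s) ** 2)

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 3))
z2 = z + 0.1 * rng.normal(size=z.shape)
loss = latent_cycle_mse(z, z2)
```

Because the standardizer is refit per call, uniformly rescaling both inputs leaves the loss essentially unchanged, which is exactly the "cheating" the standardization guards against.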
- SysVAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0, reconstruction_weight=1.0, z_distance_cycle_weight=2.0, compute_cycle=True)[source]#
Compute loss of forward pass.
- Parameters:
  - inference_outputs (dict[str, Tensor]) – Outputs of the normal and cycle inference passes.
  - generative_outputs (dict[str, Tensor]) – Outputs of the normal generative pass.
  - kl_weight (float, default: 1.0) – Weight for the KL loss.
  - reconstruction_weight (float, default: 1.0) – Weight for the reconstruction loss.
  - z_distance_cycle_weight (float, default: 2.0) – Weight for the cycle loss.
- Return type:
  LossOutput
- Returns:
  Loss components; the cycle loss is added to the extra metrics as 'cycle_loss'.
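Given the weights above, the objective is a weighted sum of its components. A hedged sketch (component names are illustrative, not the module's internal keys):

```python
def total_loss(reconstruction, kl, cycle,
               kl_weight=1.0, reconstruction_weight=1.0,
               z_distance_cycle_weight=2.0):
    # Weighted sum of the loss components, mirroring the documented defaults.
    return (reconstruction_weight * reconstruction
            + kl_weight * kl
            + z_distance_cycle_weight * cycle)
```

Note that with the defaults, the cycle-consistency term is weighted twice as heavily as the reconstruction and KL terms.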
- SysVAE.random_select_batch(batch)[source]#
Randomly selects a new batch different from the real one for each cell.
- Parameters:
  batch (Tensor) – Tensor containing the real batch index for each cell.
- Return type:
  Tensor
- Returns:
  Tensor with newly assigned batch indices for each cell.
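One way to implement this resampling, sketched in NumPy (the module itself works on torch tensors; the function name and signature here are illustrative):

```python
import numpy as np

def random_select_batch(batch, n_batch, rng):
    """For each cell, draw a batch index different from its real one.

    Drawing from the n_batch - 1 remaining candidates and shifting
    draws that are >= the real index up by one guarantees the new
    index never equals the original.
    """
    draw = rng.integers(0, n_batch - 1, size=batch.shape)
    return draw + (draw >= batch)

rng = np.random.default_rng(0)
batch = np.array([0, 1, 2, 2, 0])
new_batch = random_select_batch(batch, n_batch=3, rng=rng)
```

The shift trick avoids rejection sampling: every cell gets a valid, different batch index in a single vectorized draw.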