JVAE

class scvi.models.JVAE(dim_input_list, total_genes, indices_mappings, reconstruction_losses, model_library_bools, n_latent=10, n_layers_encoder_individual=1, n_layers_encoder_shared=1, dim_hidden_encoder=128, n_layers_decoder_individual=0, n_layers_decoder_shared=0, dim_hidden_decoder_individual=32, dim_hidden_decoder_shared=128, dropout_rate_encoder=0.1, dropout_rate_decoder=0.3, n_batch=0, n_labels=0, dispersion='gene-batch', log_variational=True)

Bases: torch.nn.modules.module.Module

Joint Variational auto-encoder for imputing missing genes in spatial data

Implementation of gimVI [Lopez19].

dim_input_list

List of the number of input genes for each dataset. If the datasets have different sizes, the dataloader will loop on the smallest until it reaches the size of the longest one

total_genes

Total number of different genes

indices_mappings

List of mappings from the model inputs to the model outputs. E.g. [[0,2], [0,1,3,2]] means the first dataset has 2 genes that will be reconstructed at locations [0,2], and the second dataset has 4 genes that will be reconstructed at locations [0,1,3,2]

reconstruction_losses

List of distributions to use in the generative process, one per dataset: 'zinb', 'nb' or 'poisson'

model_library_bools

List of bools: for each dataset, whether to model the library size with a latent variable or to use the observed values

n_latent

dimension of latent space

n_layers_encoder_individual

number of individual layers in the encoder

n_layers_encoder_shared

number of shared layers in the encoder

dim_hidden_encoder

dimension of the hidden layers in the encoder

n_layers_decoder_individual

number of layers that are conditionally batchnormed in the decoder

n_layers_decoder_shared

number of shared layers in the decoder

dim_hidden_decoder_individual

dimension of the individual hidden layers in the decoder

dim_hidden_decoder_shared

dimension of the shared hidden layers in the decoder

dropout_rate_encoder

dropout rate of the encoder

dropout_rate_decoder

dropout rate of the decoder

n_batch

total number of batches

n_labels

total number of labels

dispersion

Dispersion parameter option, 'gene-batch' by default; see vae.py for the available options

log_variational

Whether to apply log(data + 1) to the input prior to encoding, for numerical stability. This is not a normalization.
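
The relationship between dim_input_list, total_genes and indices_mappings can be checked with a small sketch that reuses the example mapping given for indices_mappings. The helpers check_mappings and scatter_to_total, and the sample values, are hypothetical and not part of scvi; the sketch only illustrates the documented meaning of the mapping.

```python
def check_mappings(dim_input_list, total_genes, indices_mappings):
    """Validate that each dataset's mapping is consistent with its gene count."""
    if len(dim_input_list) != len(indices_mappings):
        raise ValueError("one mapping is needed per dataset")
    for n_genes, mapping in zip(dim_input_list, indices_mappings):
        if len(mapping) != n_genes:
            raise ValueError("mapping length must equal the dataset's gene count")
        if any(i < 0 or i >= total_genes for i in mapping):
            raise ValueError("mapping index outside the shared gene space")


def scatter_to_total(values, mapping, total_genes, fill=None):
    """Place one dataset's per-gene values at their positions in the shared gene space."""
    out = [fill] * total_genes
    for value, position in zip(values, mapping):
        out[position] = value
    return out


# Values taken from the indices_mappings example: 2 genes and 4 genes, 4 genes overall.
dim_input_list = [2, 4]
total_genes = 4
indices_mappings = [[0, 2], [0, 1, 3, 2]]
check_mappings(dim_input_list, total_genes, indices_mappings)

# Hypothetical per-gene values for one cell of each dataset.
first = scatter_to_total([10.0, 20.0], indices_mappings[0], total_genes)
second = scatter_to_total([1.0, 2.0, 3.0, 4.0], indices_mappings[1], total_genes)
print(first)   # [10.0, None, 20.0, None]
print(second)  # [1.0, 2.0, 4.0, 3.0]
```

Positions left as None in the first result are the genes that the first dataset does not measure.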

Methods Summary

decode(z, mode, library[, batch_index, y])

rtype

Tuple[Tensor, Tensor, Tensor, Tensor]

encode(x, mode)

rtype

Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Tensor]

forward(x, local_l_mean, local_l_var[, …])

Return the reconstruction loss and the Kullback-Leibler divergences

get_sample_rate(x, batch_index, *_, **__)

reconstruction_loss(x, px_rate, px_r, …)

rtype

Tensor

sample_from_posterior_l(x[, mode, deterministic])

Sample the tensor of library sizes from the posterior

sample_from_posterior_z(x[, mode, deterministic])

Sample tensor of latent values from the posterior

sample_rate(x, mode, batch_index[, y, …])

Returns the tensor of scaled frequencies of expression

sample_scale(x, mode, batch_index[, y, …])

Return the tensor of predicted frequencies of expression

Methods Documentation

decode(z, mode, library, batch_index=None, y=None)
Return type

Tuple[Tensor, Tensor, Tensor, Tensor]

encode(x, mode)
Return type

Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Tensor]

forward(x, local_l_mean, local_l_var, batch_index=None, y=None, mode=None)

Return the reconstruction loss and the Kullback-Leibler divergences

Parameters
  • x (Tensor) – tensor of values with shape (batch_size, n_input) or (batch_size, n_input_fish) depending on the mode

  • local_l_mean (Tensor) – tensor of means of the prior distribution of latent variable l with shape (batch_size, 1)

  • local_l_var (Tensor) – tensor of variances of the prior distribution of latent variable l with shape (batch_size, 1)

  • batch_index (Optional[Tensor]) – array that indicates which batch the cells belong to with shape batch_size

  • y (Optional[Tensor]) – tensor of cell-types labels with shape (batch_size, n_labels)

  • mode (Optional[int]) – indicates which head/tail to use in the joint network

Return type

Tuple[Tensor, Tensor]

Returns

the reconstruction loss and the Kullback-Leibler divergences
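
As a rough illustration of how local_l_mean and local_l_var might enter the divergence term, the sketch below computes the closed-form KL divergence between two univariate Gaussians. It assumes that the library-size posterior and its prior are both Gaussian, which the parameter descriptions suggest but this page does not state. The function name kl_normal and the numbers are hypothetical.

```python
import math


def kl_normal(q_mean, q_var, p_mean, p_var):
    """KL(N(q_mean, q_var) || N(p_mean, p_var)) for univariate Gaussians."""
    return 0.5 * (
        math.log(p_var / q_var)
        + (q_var + (q_mean - p_mean) ** 2) / p_var
        - 1.0
    )


# Hypothetical prior parameters for one cell's library size.
local_l_mean, local_l_var = 6.0, 0.5

# A posterior identical to the prior has zero divergence.
kl_same = kl_normal(6.0, 0.5, local_l_mean, local_l_var)

# Shifting the posterior mean by 1 gives 0.5 * ((0.5 + 1) / 0.5 - 1) = 1.0.
kl_shifted = kl_normal(7.0, 0.5, local_l_mean, local_l_var)
print(kl_same, kl_shifted)
```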

get_sample_rate(x, batch_index, *_, **__)

reconstruction_loss(x, px_rate, px_r, px_dropout, mode)
Return type

Tensor
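
For the 'poisson' option listed under reconstruction_losses, the loss is presumably a negative log-likelihood of the observed counts under the predicted rates. The pure-Python sketch below assumes it is summed over genes for each cell, which this page does not confirm; px_r and px_dropout would only be relevant to 'nb' and 'zinb'. The function name poisson_nll is hypothetical.

```python
import math


def poisson_nll(x, px_rate):
    """Negative Poisson log-likelihood of counts x given rates px_rate, summed over genes."""
    total = 0.0
    for count, rate in zip(x, px_rate):
        # log P(count | rate) = count * log(rate) - rate - log(count!)
        log_prob = count * math.log(rate) - rate - math.lgamma(count + 1)
        total -= log_prob
    return total


# One cell with three genes, each with a predicted rate of 1.
loss = poisson_nll([0, 1, 2], [1.0, 1.0, 1.0])
print(loss)  # 3 + log(2)
```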

sample_from_posterior_l(x, mode=None, deterministic=False)

Sample the tensor of library sizes from the posterior

Parameters
  • x (Tensor) – tensor of values with shape (batch_size, n_input) or (batch_size, n_input_fish) depending on the mode

  • mode (Optional[int]) – head id to use in the encoder

  • deterministic (bool) – if True, return a deterministic value instead of sampling

Return type

Tensor

Returns

tensor of shape (batch_size, 1)

sample_from_posterior_z(x, mode=None, deterministic=False)

Sample tensor of latent values from the posterior

Parameters
  • x (Tensor) – tensor of values with shape (batch_size, n_input)

  • mode (Optional[int]) – head id to use in the encoder

  • deterministic (bool) – if True, return a deterministic value instead of sampling

Return type

Tensor

Returns

tensor of shape (batch_size, n_latent)
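
The effect of the deterministic flag can be sketched in pure Python. This assumes a Gaussian posterior over the latent values and assumes that the deterministic value is the posterior mean; neither is stated on this page. The names sample_z, qz_m and qz_v are hypothetical.

```python
import math
import random


def sample_z(qz_m, qz_v, deterministic=False, rng=random):
    """Return the posterior mean, or a reparameterized draw mean + sqrt(var) * noise."""
    if deterministic:
        return list(qz_m)
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(qz_m, qz_v)]


# Hypothetical posterior parameters for one cell with n_latent = 2.
qz_m = [0.5, -1.0]
qz_v = [0.1, 0.2]

z_mean = sample_z(qz_m, qz_v, deterministic=True)  # always the mean
random.seed(0)
z_draw = sample_z(qz_m, qz_v)  # a random draw around the mean
print(z_mean, z_draw)
```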

sample_rate(x, mode, batch_index, y=None, deterministic=False, decode_mode=None)

Returns the tensor of scaled frequencies of expression

Parameters
  • x (Tensor) – tensor of values with shape (batch_size, n_input) or (batch_size, n_input_fish) depending on the mode

  • y (Optional[Tensor]) – tensor of cell-types labels with shape (batch_size, n_labels)

  • mode (int) – encode mode (which input head to use in the model)

  • batch_index (Tensor) – array that indicates which batch the cells belong to with shape batch_size

  • deterministic (bool) – if True, return a deterministic value instead of sampling

  • decode_mode (Optional[int]) – decode mode to use when it differs from the encoding mode

Return type

Tensor

Returns

tensor of means of the scaled frequencies

sample_scale(x, mode, batch_index, y=None, deterministic=False, decode_mode=None)

Return the tensor of predicted frequencies of expression

Parameters
  • x (Tensor) – tensor of values with shape (batch_size, n_input) or (batch_size, n_input_fish) depending on the mode

  • mode (int) – encode mode (which input head to use in the model)

  • batch_index (Tensor) – array that indicates which batch the cells belong to with shape batch_size

  • y (Optional[Tensor]) – tensor of cell-types labels with shape (batch_size, n_labels)

  • deterministic (bool) – if True, return a deterministic value instead of sampling

  • decode_mode (Optional[int]) – decode mode to use when it differs from the encoding mode

Return type

Tensor

Returns

tensor of predicted expression
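
One plausible use of decode_mode is to encode cells with one dataset's head and decode them with the other's, for example to predict genes that the first dataset does not measure. The sketch below only demonstrates the call pattern with the signature documented above. StubJVAE is a hypothetical stand-in that records its arguments and returns placeholder values, and falling back to mode when decode_mode is None is an assumption.

```python
class StubJVAE:
    """Hypothetical stand-in for a trained JVAE; it only records how sample_scale is called."""

    def __init__(self):
        self.calls = []

    def sample_scale(self, x, mode, batch_index, y=None, deterministic=False, decode_mode=None):
        # Assumption: with decode_mode=None the encoding mode is reused for decoding.
        if decode_mode is None:
            decode_mode = mode
        self.calls.append((mode, decode_mode, deterministic))
        # Placeholder output, not real predictions: one row of 4 values per cell.
        return [[0.0] * 4 for _ in x]


def decode_with_other_head(model, x, batch_index, mode, decode_mode):
    """Encode x with head `mode`, then decode with head `decode_mode`, without sampling."""
    return model.sample_scale(
        x, mode, batch_index, deterministic=True, decode_mode=decode_mode
    )


model = StubJVAE()
x = [[3, 0], [1, 5]]   # two cells from the 2-gene dataset (hypothetical counts)
batch_index = [0, 0]
imputed = decode_with_other_head(model, x, batch_index, mode=0, decode_mode=1)
print(model.calls)  # [(0, 1, True)]
```

With a real model, x and batch_index would be torch tensors, and which dataset corresponds to mode 0 or 1 would follow the order used in dim_input_list.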