scvi.external.gimvi.JVAE
- class scvi.external.gimvi.JVAE(dim_input_list, total_genes, indices_mappings, gene_likelihoods, model_library_bools, library_log_means, library_log_vars, n_latent=10, n_layers_encoder_individual=1, n_layers_encoder_shared=1, dim_hidden_encoder=64, n_layers_decoder_individual=0, n_layers_decoder_shared=0, dim_hidden_decoder_individual=64, dim_hidden_decoder_shared=64, dropout_rate_encoder=0.2, dropout_rate_decoder=0.2, n_batch=0, n_labels=0, dispersion='gene-batch', log_variational=True)
Bases: scvi.module.base._base_module.BaseModuleClass
Joint variational auto-encoder for imputing missing genes in spatial data.
Implementation of gimVI [Lopez19].
- Parameters
- dim_input_list : List[int]
  List of the number of input genes for each dataset. If the datasets have different sizes (numbers of cells), the dataloader loops over the smallest one until it reaches the size of the longest one.
- total_genes : int
  Total number of distinct genes across all datasets.
- indices_mappings : List[Union[ndarray, slice]]
  List of mappings from each dataset's input genes to positions in the model output. For example, [[0, 2], [0, 1, 3, 2]] means that the first dataset has 2 genes, reconstructed at output positions [0, 2], and the second dataset has 4 genes, reconstructed at output positions [0, 1, 3, 2].
- gene_likelihoods : List[str]
  List of distributions to use in the generative process, one per dataset: 'zinb', 'nb', or 'poisson'.
- model_library_bools : List[bool]
  For each dataset, whether to model the library size with a latent variable (True) or use the observed values (False).
- library_log_means : List[ndarray]
  List of 1 x n_batch arrays of means of the log library sizes. Parameterizes the prior on library size if not using observed library sizes.
- library_log_vars : List[ndarray]
  List of 1 x n_batch arrays of variances of the log library sizes. Parameterizes the prior on library size if not using observed library sizes.
- n_latent : int (default: 10)
  Dimension of the latent space.
- n_layers_encoder_individual : int (default: 1)
  Number of individual layers in the encoder.
- n_layers_encoder_shared : int (default: 1)
  Number of shared layers in the encoder.
- dim_hidden_encoder : int (default: 64)
  Dimension of the hidden layers in the encoder.
- n_layers_decoder_individual : int (default: 0)
  Number of layers in the decoder that are conditionally batch-normalized.
- n_layers_decoder_shared : int (default: 0)
  Number of shared layers in the decoder.
- dim_hidden_decoder_individual : int (default: 64)
  Dimension of the individual hidden layers in the decoder.
- dim_hidden_decoder_shared : int (default: 64)
  Dimension of the shared hidden layers in the decoder.
- dropout_rate_encoder : float (default: 0.2)
  Dropout rate in the encoder.
- dropout_rate_decoder : float (default: 0.2)
  Dropout rate in the decoder.
- n_batch : int (default: 0)
  Total number of batches.
- n_labels : int (default: 0)
  Total number of labels.
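- dispersion : str (default: 'gene-batch')
  Parameterization of the dispersion. See the dispersion parameter of the VAE module (vae.py) for the available options.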
- log_variational : bool (default: True)
  If True, apply log(data + 1) to the input before encoding, for numerical stability. This is not a normalization.
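The bookkeeping described by dim_input_list and indices_mappings can be sketched in plain Python. This is a minimal illustration, not the module's code: it assumes the decoder produces one value per gene in total_genes, it handles mappings given as integer lists only (not slices), and the helper names select_dataset_genes, unmeasured_positions and joint_batches are hypothetical. The first two show how a mapping picks one dataset's genes out of a full reconstruction and which output positions that dataset leaves to be imputed; the third mimics the dataloader looping over the smaller dataset until the longer one is exhausted.

```python
from itertools import cycle, islice
from typing import List, Sequence, Tuple, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def select_dataset_genes(full_output: Sequence[float], mapping: Sequence[int]) -> List[float]:
    """Hypothetical helper: pick one dataset's genes out of a total_genes-long output."""
    return [full_output[position] for position in mapping]


def unmeasured_positions(mapping: Sequence[int], total_genes: int) -> List[int]:
    """Hypothetical helper: output positions a dataset does not measure (imputation targets)."""
    measured = set(mapping)
    return [i for i in range(total_genes) if i not in measured]


def joint_batches(batches_a: Sequence[A], batches_b: Sequence[B]) -> List[Tuple[A, B]]:
    """Hypothetical helper: pair minibatches, cycling the shorter sequence
    until the longer one is exhausted."""
    n = max(len(batches_a), len(batches_b))
    return list(zip(islice(cycle(batches_a), n), islice(cycle(batches_b), n)))


# Mapping taken from the indices_mappings description above.
indices_mappings = [[0, 2], [0, 1, 3, 2]]
total_genes = 4
full_output = [0.1, 0.2, 0.3, 0.4]  # made-up values, one per output gene

print(select_dataset_genes(full_output, indices_mappings[0]))  # [0.1, 0.3]
print(select_dataset_genes(full_output, indices_mappings[1]))  # [0.1, 0.2, 0.4, 0.3]
print(unmeasured_positions(indices_mappings[0], total_genes))  # [1, 3]
print(joint_batches(["a1", "a2", "a3"], ["b1"]))
# [('a1', 'b1'), ('a2', 'b1'), ('a3', 'b1')]
```

In this example the first dataset never measures output positions 1 and 3, so those are the genes a joint model would have to impute for it.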
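Attributes table
| T_destination | alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor]) |
| device | |
| dump_patches | This allows better backward-compatibility (BC) support for load_state_dict(). |
| training | |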
Methods table
| generative | Run the generative model. |
| get_sample_rate | |
| inference | Run the inference (recognition) model. |
| loss | Return the reconstruction loss and the Kullback-Leibler divergences. |
| reconstruction_loss | |
| sample_from_posterior_l | Sample the tensor of library sizes from the posterior. |
| sample_from_posterior_z | Sample the tensor of latent values from the posterior. |
| sample_rate | Return the tensor of scaled frequencies of expression. |
| sample_scale | Return the tensor of predicted frequencies of expression. |
Attributes
T_destination
- JVAE.T_destination
alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
device
- JVAE.device
dump_patches
- JVAE.dump_patches: bool = False
This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number will be saved in the attribute _metadata of the returned state dict, and thus pickled. _metadata is a dictionary with keys that follow the naming convention of state dict. See _load_from_state_dict on how to use this information in loading. If new parameters/buffers are added to or removed from a module, this number shall be bumped, and the module's _load_from_state_dict method can compare the version number and make appropriate changes if the state dict is from before the change.
training
Methods
generative
- JVAE.generative(z, library, batch_index=None, y=None, mode=None)
Run the generative model.
This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x|z)\).
This function should return a dictionary with str keys and Tensor values.
- Return type: dict
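For the 'nb' and 'poisson' options of gene_likelihoods, the likelihood parameters returned here are typically a per-gene mean (rate) and, for the negative binomial, an inverse dispersion. The sketch below evaluates the corresponding log-probabilities with the standard mean/inverse-dispersion parameterization. It is a plain-Python illustration with hypothetical function names, not the module's implementation; the zero-inflated 'zinb' case adds a dropout probability on top of the negative binomial and is omitted.

```python
import math


def poisson_log_prob(x: int, mu: float) -> float:
    """log P(x) for a Poisson distribution with mean mu > 0."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)


def nb_log_prob(x: int, mu: float, theta: float) -> float:
    """log P(x) for a negative binomial with mean mu > 0 and inverse dispersion theta > 0."""
    log_theta_mu = math.log(theta + mu)
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mu) - log_theta_mu)
    )


print(poisson_log_prob(0, 1.0))  # -1.0
print(nb_log_prob(0, 1.0, 1.0))  # log(0.5), about -0.693
# With a very large theta the negative binomial approaches the Poisson.
print(nb_log_prob(3, 2.0, 1e6), poisson_log_prob(3, 2.0))
```

A reconstruction loss built on these would be the negative sum of such log-probabilities over the genes of a cell.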
get_sample_rate
inference
- JVAE.inference(x, mode=None)
Run the inference (recognition) model.
In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.
This function should return a dictionary with str keys and Tensor values.
- Return type: dict
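In a VAE the inference step usually transforms the counts, runs the encoder to get the mean and variance of a diagonal Gaussian posterior, and draws z with the reparameterization trick. The sketch below shows only the first and last of those steps in plain Python: the encoder is replaced by made-up (mean, variance) values, the helper names are hypothetical, and reading deterministic=True as "return the posterior mean" is an assumption that mirrors the deterministic flag of sample_rate and sample_scale.

```python
import math
import random
from typing import List, Sequence


def log_transform(counts: Sequence[float]) -> List[float]:
    """log(data + 1), as applied to the input when log_variational=True."""
    return [math.log1p(c) for c in counts]


def sample_latent(
    mean: Sequence[float],
    var: Sequence[float],
    rng: random.Random,
    deterministic: bool = False,
) -> List[float]:
    """Draw z ~ N(mean, diag(var)) as mean + sqrt(var) * eps with eps ~ N(0, 1).

    With deterministic=True the mean is returned instead of a sample.
    """
    if deterministic:
        return list(mean)
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mean, var)]


counts = [0.0, 3.0, 10.0]
print(log_transform(counts))  # [0.0, log(4), log(11)]

# Hypothetical encoder output for a 2-dimensional latent space.
q_mean, q_var = [0.5, -1.0], [0.04, 0.25]
rng = random.Random(0)
print(sample_latent(q_mean, q_var, rng))  # a random draw around q_mean
print(sample_latent(q_mean, q_var, rng, deterministic=True))  # [0.5, -1.0]
```

Writing the sample as mean plus scaled noise keeps it differentiable with respect to the encoder outputs, which is why VAEs sample this way.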
loss
- JVAE.loss(tensors, inference_outputs, generative_outputs, mode=None, kl_weight=1.0)
Return the reconstruction loss and the Kullback-Leibler divergences.
- Parameters
- x
  Tensor of values with shape (batch_size, n_input) or (batch_size, n_input_fish), depending on the mode.
- batch_index
  Array that indicates which batch the cells belong to, with shape batch_size.
- y
  Tensor of cell-type labels with shape (batch_size, n_labels).
- mode : int | None (default: None)
  Indicates which head/tail to use in the joint network.
- Returns
  The reconstruction loss and the Kullback-Leibler divergences.
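Both Kullback-Leibler terms have a closed form when posterior and prior are diagonal Gaussians: the latent term against a standard normal prior, and the library-size term against the log-space prior given by library_log_means and library_log_vars for the cell's batch. The sketch below is plain Python with hypothetical names; indexing the priors as [mode][batch] on a flattened 1 x n_batch row is an assumption, and weighted_loss only illustrates a common KL-annealing form, since the description above does not say how kl_weight enters.

```python
import math
from typing import Sequence


def kl_diag_normal(
    q_mean: Sequence[float],
    q_var: Sequence[float],
    p_mean: Sequence[float],
    p_var: Sequence[float],
) -> float:
    """KL(q || p) for diagonal Gaussians, summed over dimensions."""
    return sum(
        0.5 * (math.log(pv / qv) + (qv + (qm - pm) ** 2) / pv - 1.0)
        for qm, qv, pm, pv in zip(q_mean, q_var, p_mean, p_var)
    )


# Latent KL against a standard normal prior N(0, I).
kl_z = kl_diag_normal([1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])
print(kl_z)  # 0.5

# Library-size KL against the prior of the cell's batch (made-up values):
# one flattened 1 x n_batch row per dataset.
library_log_means = [[6.0, 7.0]]
library_log_vars = [[0.5, 0.3]]
mode, batch = 0, 1
kl_l = kl_diag_normal(
    [7.0], [0.3], [library_log_means[mode][batch]], [library_log_vars[mode][batch]]
)
print(kl_l)  # 0.0


def weighted_loss(reconstruction_loss: float, kl_local: float, kl_weight: float = 1.0) -> float:
    """Hypothetical KL-annealed objective: reconstruction + kl_weight * KL."""
    return reconstruction_loss + kl_weight * kl_local


print(weighted_loss(10.0, kl_z + kl_l, kl_weight=0.5))  # 10.25
```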
reconstruction_loss
sample_from_posterior_l
sample_from_posterior_z
sample_rate
- JVAE.sample_rate(x, mode, batch_index, y=None, deterministic=False, decode_mode=None)
Returns the tensor of scaled frequencies of expression.
- Parameters
- x : Tensor
  Tensor of values with shape (batch_size, n_input) or (batch_size, n_input_fish), depending on the mode.
- y : Tensor | None (default: None)
  Tensor of cell-type labels with shape (batch_size, n_labels).
- mode : int
  Encoding mode (which input head to use in the model).
- batch_index : Tensor
  Array that indicates which batch the cells belong to, with shape batch_size.
- deterministic : bool (default: False)
  If True, use the posterior mean instead of sampling; if False, sample.
- decode_mode : int | None (default: None)
  Decoding mode, used to decode with a head different from the encoding mode.
- Return type: Tensor
- Returns
  Tensor of means of the scaled frequencies.
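In scVI-style models, the expression rate is the per-gene frequency scaled by the cell's library size, with the library kept in log space (as library_log_means and library_log_vars suggest). Whether sample_rate computes exactly this is an assumption here, not something the description above states; expression_rate is a hypothetical helper.

```python
import math
from typing import List, Sequence


def expression_rate(scale: Sequence[float], log_library: float) -> List[float]:
    """Hypothetical helper: turn per-gene frequencies into expected counts
    by multiplying them with the library size exp(log_library)."""
    library = math.exp(log_library)
    return [library * s for s in scale]


scale = [0.25, 0.75]  # frequencies for one cell, summing to 1
rate = expression_rate(scale, math.log(100.0))
print(rate)  # approximately [25.0, 75.0]
```

Because the frequencies sum to 1, the rates of a cell sum to its library size.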
sample_scale
- JVAE.sample_scale(x, mode, batch_index, y=None, deterministic=False, decode_mode=None)
Return the tensor of predicted frequencies of expression.
- Parameters
- x : Tensor
  Tensor of values with shape (batch_size, n_input) or (batch_size, n_input_fish), depending on the mode.
- mode : int
  Encoding mode (which input head to use in the model).
- batch_index : Tensor
  Array that indicates which batch the cells belong to, with shape batch_size.
- y : Tensor | None (default: None)
  Tensor of cell-type labels with shape (batch_size, n_labels).
- deterministic : bool (default: False)
  If True, use the posterior mean instead of sampling; if False, sample.
- decode_mode : int | None (default: None)
  Decoding mode, used to decode with a head different from the encoding mode.
- Return type: Tensor
- Returns
  Tensor of predicted expression.
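The predicted frequencies are per-gene proportions for each cell. A common way to obtain them, used by scVI-style decoders and assumed here rather than taken from the JVAE source, is a softmax over the decoder's output for all genes, which makes each cell's frequencies non-negative and sum to 1. The sketch below is a numerically stable softmax in plain Python.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: shift by the max before exponentiating."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


freqs = softmax([2.0, 1.0, 0.0])
print(freqs, sum(freqs))  # decreasing frequencies that sum to (about) 1
print(softmax([0.0, 0.0, 0.0, 0.0]))  # [0.25, 0.25, 0.25, 0.25]
```

Subtracting the maximum leaves the result unchanged mathematically but avoids overflow in math.exp for large decoder outputs.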