scvi.external.gimvi.JVAE.loss

JVAE.loss(tensors, inference_outputs, generative_outputs, mode=None, kl_weight=1.0)

Return the reconstruction loss and the Kullback-Leibler divergences.
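The parameters local_l_mean and local_l_var describe a per-cell Gaussian prior on the latent variable l, so one KL term compares an approximate posterior over l against that prior. The sketch below shows the closed-form KL divergence between two univariate normals, assuming both the posterior and the prior over l are Gaussian per cell. The helper names gaussian_kl and library_kl_per_cell and the posterior arguments ql_means and ql_vars are hypothetical, and this is not the scvi-tools implementation.

```python
import math
from typing import List, Sequence


def gaussian_kl(mu_q: float, var_q: float, mu_p: float, var_p: float) -> float:
    """Closed-form KL(N(mu_q, var_q) || N(mu_p, var_p)) for univariate normals."""
    if var_q <= 0 or var_p <= 0:
        raise ValueError("variances must be positive")
    return 0.5 * (
        math.log(var_p / var_q)
        + (var_q + (mu_q - mu_p) ** 2) / var_p
        - 1.0
    )


def library_kl_per_cell(
    ql_means: Sequence[float],      # hypothetical posterior means of l, one per cell
    ql_vars: Sequence[float],       # hypothetical posterior variances of l, one per cell
    local_l_mean: Sequence[float],  # prior means, (batch_size, 1) flattened to batch_size
    local_l_var: Sequence[float],   # prior variances, flattened the same way
) -> List[float]:
    """KL divergence of each cell's (assumed Gaussian) posterior over l from its prior."""
    n = len(ql_means)
    if not (len(ql_vars) == len(local_l_mean) == len(local_l_var) == n):
        raise ValueError("all inputs must have one entry per cell")
    return [
        gaussian_kl(mq, vq, mp, vp)
        for mq, vq, mp, vp in zip(ql_means, ql_vars, local_l_mean, local_l_var)
    ]


print(library_kl_per_cell([0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))  # [0.0, 0.5]
```

The KL is zero when a cell's posterior matches its prior exactly and grows with the squared gap between the means, scaled by the prior variance.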

Parameters
x

tensor of values with shape (batch_size, n_input) or (batch_size, n_input_fish) depending on the mode

local_l_mean

tensor of means of the prior distribution of latent variable l with shape (batch_size, 1)

local_l_var

tensor of variances of the prior distribution of latent variable l with shape (batch_size, 1)

batch_index

array indicating the batch to which each cell belongs, of length batch_size

y

tensor of cell-type labels with shape (batch_size, n_labels)

mode : Optional[int] (default: None)

indicates which head/tail to use in the joint network
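Because x has shape (batch_size, n_input) or (batch_size, n_input_fish) depending on mode, the integer mode effectively selects both a network head and the input width it expects. The sketch below illustrates that kind of dispatch with plain Python callables standing in for network heads. The helper apply_head and the heads and input_dims mappings are hypothetical, and how the real method treats mode=None is not documented on this page, so the sketch simply rejects it.

```python
from typing import Callable, Dict, List, Optional, Sequence


def apply_head(
    x: Sequence[Sequence[float]],
    mode: Optional[int],
    heads: Dict[int, Callable[[Sequence[float]], float]],
    input_dims: Dict[int, int],
) -> List[float]:
    """Route every row of x through the head selected by mode (hypothetical helper)."""
    if mode is None:
        # The real method's handling of mode=None is not documented here.
        raise ValueError("this sketch requires an explicit integer mode")
    if mode not in heads or mode not in input_dims:
        raise KeyError(f"no head registered for mode {mode}")
    width = input_dims[mode]
    for row in x:
        # Each mode expects its own feature count (e.g. n_input vs n_input_fish).
        if len(row) != width:
            raise ValueError(f"mode {mode} expects {width} features, got {len(row)}")
    head = heads[mode]
    return [head(row) for row in x]


# Toy stand-ins for two network heads with different input widths.
heads = {0: sum, 1: max}
input_dims = {0: 3, 1: 2}
print(apply_head([[1.0, 2.0, 3.0]], 0, heads, input_dims))  # [6.0]
```

Keeping the per-mode input width next to the head makes a shape mismatch fail loudly instead of silently feeding data to the wrong branch.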

Return type

Tuple[Tensor, Tensor]

Returns

the reconstruction loss and the Kullback-Leibler divergences
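A caller typically reduces the returned pair to a single scalar for optimization. The sketch below shows one common negative-ELBO-style reduction, the batch mean of reconstruction loss plus a weighted KL term, assuming each returned tensor holds one value per cell and that kl_weight acts as a multiplier on the KL term. Both are assumptions not confirmed by this page; the helper weighted_objective is hypothetical and uses plain Python sequences in place of tensors.

```python
from typing import Sequence, Tuple


def weighted_objective(
    losses: Tuple[Sequence[float], Sequence[float]], kl_weight: float = 1.0
) -> float:
    """Batch mean of reconstruction loss plus kl_weight times the KL divergence.

    A common VAE reduction, assumed here rather than taken from scvi-tools.
    """
    reconst, kl = losses
    if len(reconst) != len(kl) or len(reconst) == 0:
        raise ValueError("expected two non-empty per-cell sequences of equal length")
    total = sum(r + kl_weight * k for r, k in zip(reconst, kl))
    return total / len(reconst)


print(weighted_objective(([10.0, 12.0], [2.0, 4.0]), kl_weight=0.5))  # 12.5
```

Lowering the KL weight early in training (KL warm-up) is a common way to keep the latent space from collapsing onto the prior before the reconstruction term has been fit.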