scvi.module.base.LossOutput

class scvi.module.base.LossOutput(loss, reconstruction_loss=None, kl_local=None, kl_global=None, classification_loss=None, logits=None, true_labels=None, extra_metrics=<factory>, n_obs_minibatch=None, reconstruction_loss_sum=None, kl_local_sum=None, kl_global_sum=None)

Bases: object

Loss signature for models.

This class provides an organized way to record the model loss together with the components of the evidence lower bound (ELBO). It may also be used with MLE, MAP, or EM methods. The loss is used for backpropagation during inference (model fitting); the other fields are used for logging and early stopping.
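The split between the single backpropagated loss and the per-observation ELBO components can be illustrated with a NumPy sketch of a typical VAE-style minibatch loss. The values, the `minibatch_loss` helper and its `kl_weight` parameter (standing in for KL warm-up) are hypothetical, and whether a given scvi-tools module combines its terms exactly this way is an assumption.

```python
import numpy as np

# Hypothetical per-observation terms for a minibatch of 4 observations.
reconstruction_loss = np.array([10.0, 12.0, 8.0, 10.0])  # e.g. -log p(x | z)
kl_local = np.array([1.0, 2.0, 0.5, 0.5])  # e.g. KL(q(z | x) || p(z))
kl_global = np.array(0.0)  # no global latent variables in this sketch


def minibatch_loss(reconstruction_loss, kl_local, kl_global, kl_weight=1.0):
    """Negative-ELBO-style scalar loss for backpropagation (illustrative only)."""
    # Average the per-observation terms so the scale does not depend on batch size,
    # then add the global KL term once. This combination is an assumption.
    total = np.mean(reconstruction_loss + kl_weight * kl_local) + kl_global
    # Keep the result as an array rather than a Python float.
    return np.asarray(total)


loss = minibatch_loss(reconstruction_loss, kl_local, kl_global)
print(loss)  # 11.0
```

Passing `loss` together with the unreduced `reconstruction_loss` and `kl_local` arrays keeps both the quantity to optimize and the per-observation terms available for logging.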

Parameters:
  • loss (Union[dict[str, Union[Tensor, Array]], Tensor, Array]) – Tensor with the loss for the minibatch. Should be one-dimensional with a single value. The loss must be an array/tensor, not a Python float.

  • reconstruction_loss (Union[dict[str, Union[Tensor, Array]], Tensor, Array, None] (default: None)) – Reconstruction loss for each observation in the minibatch. If a tensor is passed, it is converted to a dictionary with the key “reconstruction_loss” and the tensor as its value.

  • kl_local (Union[dict[str, Union[Tensor, Array]], Tensor, Array, None] (default: None)) – KL divergence associated with each observation in the minibatch. If a tensor is passed, it is converted to a dictionary with the key “kl_local” and the tensor as its value.

  • kl_global (Union[dict[str, Union[Tensor, Array]], Tensor, Array, None] (default: None)) – Global KL divergence term. Should be one-dimensional with a single value. If a tensor is passed, it is converted to a dictionary with the key “kl_global” and the tensor as its value.

  • classification_loss (Union[dict[str, Union[Tensor, Array]], Tensor, Array, None] (default: None)) – Classification loss.

  • logits (Union[Tensor, Array, None] (default: None)) – Logits for classification.

  • true_labels (Union[Tensor, Array, None] (default: None)) – True labels for classification.

  • extra_metrics (dict[str, Union[Tensor, Array]] | None (default: <factory>)) – Additional metrics can be passed as arrays/tensors or dictionaries of arrays/tensors.

  • n_obs_minibatch (int | None (default: None)) – Number of observations in the minibatch. If None, will be inferred from the shape of the reconstruction_loss tensor.
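The parameter descriptions above say that bare tensors are wrapped into single-entry dictionaries and that n_obs_minibatch falls back to the shape of reconstruction_loss. The NumPy sketch below mimics that bookkeeping; `as_dict`, `normalize` and `Record` are hypothetical names, not part of scvi-tools. It also computes totals of the kind the undocumented reconstruction_loss_sum and kl_local_sum fields appear to hold, which is an assumption.

```python
from typing import Optional, Union

import numpy as np

Record = Union[np.ndarray, dict[str, np.ndarray]]


def as_dict(value: Record, name: str) -> dict[str, np.ndarray]:
    """Wrap a bare array as {name: array}; pass dictionaries through unchanged."""
    return value if isinstance(value, dict) else {name: value}


def normalize(reconstruction_loss: Record, kl_local: Record,
              n_obs_minibatch: Optional[int] = None):
    """Hypothetical helper mirroring the parameter descriptions (not library code)."""
    rec = as_dict(reconstruction_loss, "reconstruction_loss")
    kl = as_dict(kl_local, "kl_local")
    if n_obs_minibatch is None:
        # Infer the minibatch size from the leading axis of the reconstruction term.
        n_obs_minibatch = next(iter(rec.values())).shape[0]
    # Assumed "*_sum" semantics: add the named components, then sum over observations.
    rec_sum = sum(rec.values()).sum()
    kl_sum = sum(kl.values()).sum()
    return rec, kl, n_obs_minibatch, rec_sum, kl_sum


rec, kl, n_obs, rec_sum, kl_sum = normalize(
    reconstruction_loss={"gene": np.array([1.0, 2.0, 3.0]), "protein": np.array([0.5, 0.5, 0.5])},
    kl_local=np.array([0.1, 0.2, 0.3]),
)
print(sorted(rec), list(kl), n_obs)  # ['gene', 'protein'] ['kl_local'] 3
print(rec_sum)  # 7.5
```

Keeping each term as a named dictionary lets a multi-modal model report, say, gene and protein reconstruction losses separately while still exposing a single total.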

Examples

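>>> import torch
>>> from scvi.module.base import LossOutput
>>> reconstruction_loss = torch.rand(128)  # placeholder: one value per observation
>>> kl_local = torch.rand(128)  # placeholder: one value per observation
>>> loss = torch.mean(reconstruction_loss + kl_local)
>>> loss_output = LossOutput(
...     loss=loss,
...     reconstruction_loss=reconstruction_loss,
...     kl_local=kl_local,
...     extra_metrics={"x": torch.tensor(1.0), "y": torch.tensor(2.0)},
... )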
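The classification fields are only described briefly above. The NumPy sketch below shows one way a classifier head might produce values for them; using softmax cross-entropy as the classification loss and accuracy as an extra metric are assumptions, and `cross_entropy` is a hypothetical helper.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, true_labels: np.ndarray) -> np.ndarray:
    """Mean softmax cross-entropy over the minibatch (an assumed loss choice)."""
    # Subtract the row-wise maximum for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability assigned to the true class of each observation.
    picked = log_probs[np.arange(len(true_labels)), true_labels]
    return np.asarray(-picked.mean())


logits = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])  # 3 observations, 2 classes
true_labels = np.array([0, 1, 0])
classification_loss = cross_entropy(logits, true_labels)
# Fraction of correct predictions, a candidate entry for extra_metrics.
accuracy = np.asarray((logits.argmax(axis=1) == true_labels).mean())
```

These arrays would go into the classification_loss, logits and true_labels fields, with accuracy passed through extra_metrics.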

Attributes table

classification_loss

extra_metrics_keys

Keys for extra metrics.

kl_global

kl_global_sum

kl_local

kl_local_sum

logits

n_obs_minibatch

reconstruction_loss

reconstruction_loss_sum

true_labels

loss

extra_metrics

Methods table

dict_sum(dictionary)

Sum over elements of a dictionary.

replace(**updates)

Return a new object replacing the specified fields with new values.

Attributes

LossOutput.classification_loss: Union[dict[str, Union[Tensor, Array]], Tensor, Array, None] = None
LossOutput.extra_metrics_keys

Keys for extra metrics.

LossOutput.kl_global: Union[dict[str, Union[Tensor, Array]], Tensor, Array, None] = None
LossOutput.kl_global_sum: Union[Tensor, Array] = None
LossOutput.kl_local: Union[dict[str, Union[Tensor, Array]], Tensor, Array, None] = None
LossOutput.kl_local_sum: Union[Tensor, Array] = None
LossOutput.logits: Union[Tensor, Array, None] = None
LossOutput.n_obs_minibatch: int | None = None
LossOutput.reconstruction_loss: Union[dict[str, Union[Tensor, Array]], Tensor, Array, None] = None
LossOutput.reconstruction_loss_sum: Union[Tensor, Array] = None
LossOutput.true_labels: Union[Tensor, Array, None] = None
LossOutput.loss: Union[dict[str, Union[Tensor, Array]], Tensor, Array]
LossOutput.extra_metrics: dict[str, Union[Tensor, Array]] | None
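Fields such as reconstruction_loss_sum, kl_local_sum and n_obs_minibatch make it possible to aggregate metrics over many minibatches. The pure-Python sketch below uses a hypothetical `MinibatchRecord` stand-in (plain floats instead of tensors) and assumes an epoch-level negative ELBO of (summed local terms) / (total observations) + global KL, and minibatch-size-weighted averaging of extra metrics; neither formula is confirmed by this page.

```python
from dataclasses import dataclass, field


@dataclass
class MinibatchRecord:
    """Hypothetical stand-in holding only the LossOutput fields used below."""

    reconstruction_loss_sum: float
    kl_local_sum: float
    kl_global_sum: float
    n_obs_minibatch: int
    extra_metrics: dict = field(default_factory=dict)

    @property
    def extra_metrics_keys(self):
        return self.extra_metrics.keys()


def epoch_elbo(records):
    """Per-observation negative ELBO accumulated over minibatches (assumed formula)."""
    n_obs = sum(r.n_obs_minibatch for r in records)
    local = sum(r.reconstruction_loss_sum + r.kl_local_sum for r in records)
    # Assumption: the global KL is the same in every minibatch, so it is added once.
    return local / n_obs + records[-1].kl_global_sum


def epoch_extra_metrics(records):
    """Average each extra metric over the epoch, weighting by minibatch size (assumed)."""
    totals, n_obs = {}, sum(r.n_obs_minibatch for r in records)
    for r in records:
        for key in r.extra_metrics_keys:
            totals[key] = totals.get(key, 0.0) + float(r.extra_metrics[key]) * r.n_obs_minibatch
    return {key: value / n_obs for key, value in totals.items()}


records = [
    MinibatchRecord(40.0, 4.0, 0.5, n_obs_minibatch=4, extra_metrics={"accuracy": 0.5}),
    MinibatchRecord(20.0, 2.0, 0.5, n_obs_minibatch=2, extra_metrics={"accuracy": 1.0}),
]
print(epoch_elbo(records))  # 11.5
```

Accumulating sums and observation counts, rather than averaging per-minibatch means, keeps the result correct when the last minibatch is smaller than the others.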

Methods

static LossOutput.dict_sum(dictionary)

Sum over elements of a dictionary.
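The one-line description of dict_sum leaves open how values are combined and what happens with non-dictionary input. The sketch below assumes an element-wise sum over the dictionary values and that a bare array is returned unchanged, consistent with the fields above accepting either form; it is an illustration, not the scvi-tools source.

```python
from typing import Union

import numpy as np


def dict_sum(dictionary: Union[dict, np.ndarray]):
    """Sum the values of a dictionary; return any other input unchanged (assumed)."""
    if isinstance(dictionary, dict):
        # Element-wise sum across entries: per-observation arrays stay per-observation.
        return sum(dictionary.values())
    return dictionary


per_obs = dict_sum({"gene": np.array([1.0, 2.0]), "protein": np.array([0.5, 0.25])})
scalar = dict_sum(np.array(3.0))  # a bare array passes through unchanged
```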

LossOutput.replace(**updates)

Return a new object replacing the specified fields with new values.
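Assuming replace behaves like the standard library's dataclasses.replace, which this page does not state, the pattern looks like the sketch below. `TinyLossRecord` is a hypothetical frozen stand-in, not the scvi-tools class.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class TinyLossRecord:
    """Hypothetical stand-in; not the scvi-tools class."""

    loss: float
    n_obs_minibatch: int = 0

    def replace(self, **updates):
        """Return a copy with the given fields swapped out; the original is untouched."""
        return dataclasses.replace(self, **updates)


original = TinyLossRecord(loss=3.0, n_obs_minibatch=128)
updated = original.replace(loss=1.5)
print(original.loss, updated.loss, updated.n_obs_minibatch)  # 3.0 1.5 128
```

Because dataclasses.replace builds the copy through `__init__`, any `__post_init__` normalization would run again on the new object; whether LossOutput relies on this is an assumption worth checking before replacing fields such as reconstruction_loss.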