scvi.module.base.LossRecorder

class scvi.module.base.LossRecorder(loss, reconstruction_loss=None, kl_local=None, kl_global=None, **kwargs)

Bases: object

Loss signature for models.

This class provides an organized way to record the model loss along with the components of the evidence lower bound (ELBO). It can also be used with maximum likelihood (MLE), maximum a posteriori (MAP), and expectation-maximization (EM) methods. During inference, the loss is used for backpropagation, while the other components are used for logging and early stopping.
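To make the split between the backpropagated loss and the logged components concrete, here is a minimal sketch of how a model might produce these values for one minibatch. It is not scvi-tools code: the function name `negative_elbo_components` is hypothetical, plain Python floats stand in for tensors, and it assumes the loss is the minibatch mean of reconstruction loss plus local KL, plus the global KL term. Individual models may weight or scale these terms differently.

```python
from typing import Dict, List, Union

# Hypothetical sketch: plain Python floats stand in for tensors/arrays.
# Assumption: the minibatch loss is the mean over observations of
# (reconstruction loss + local KL) plus the global KL term. Real models
# may weight or scale these terms differently.


def negative_elbo_components(
    recon_per_obs: List[float],
    kl_local_per_obs: List[float],
    kl_global: float,
) -> Dict[str, Union[float, List[float]]]:
    """Return the scalar loss together with the components that would be logged."""
    per_obs = [r + k for r, k in zip(recon_per_obs, kl_local_per_obs)]
    loss = sum(per_obs) / len(per_obs) + kl_global
    return {
        "loss": loss,  # the value a training loop would backpropagate
        "reconstruction_loss": recon_per_obs,  # one value per observation
        "kl_local": kl_local_per_obs,  # one value per observation
        "kl_global": kl_global,  # single value for the whole model
    }


out = negative_elbo_components([1.0, 3.0], [0.5, 0.5], 0.0)
print(out["loss"])  # 2.5
```

Only `loss` is a single number; the per-observation components keep their shape so they can be summarized later for logging.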

Parameters:
  • loss (Union[Dict[str, Union[Tensor, Array]], Tensor, Array]) – Loss for the minibatch. Should be one-dimensional with a single value. Pass the tensor itself rather than the result of .item(), which returns a plain Python number that no longer carries gradients.

  • reconstruction_loss (Union[Dict[str, Union[Tensor, Array]], Tensor, Array, None] (default: None)) – Reconstruction loss for each observation in the minibatch. If a tensor, it is converted to a dictionary with the key “reconstruction_loss” and the tensor as its value.

  • kl_local (Union[Dict[str, Union[Tensor, Array]], Tensor, Array, None] (default: None)) – KL divergence associated with each observation in the minibatch. If a tensor, it is converted to a dictionary with the key “kl_local” and the tensor as its value.

  • kl_global (Union[Dict[str, Union[Tensor, Array]], Tensor, Array, None] (default: None)) – Global KL divergence term. Should be one-dimensional with a single value. If a tensor, it is converted to a dictionary with the key “kl_global” and the tensor as its value.

  • **kwargs – Additional metrics can be passed as keyword arguments and will be available as attributes of the object.
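The argument handling described above (a bare value is wrapped in a one-entry dictionary keyed by the argument name, a dictionary is kept as-is, and extra keyword arguments become attributes) can be sketched as follows. `SketchLossRecorder` and `_as_dict` are hypothetical names, not part of scvi-tools; plain numbers and lists stand in for tensors, and treating `loss` the same way as the other arguments and recording a missing term as zero are both assumptions.

```python
from typing import Dict, List, Optional, Union

# Stand-in for Tensor/Array: a scalar or one value per observation.
Value = Union[float, List[float]]


def _as_dict(value: Optional[Union[Dict[str, Value], Value]], key: str) -> Dict[str, Value]:
    """Wrap a bare value in a one-entry dict keyed by the argument name."""
    if value is None:
        return {key: 0.0}  # assumption: a missing term is recorded as zero
    if isinstance(value, dict):
        return value  # already named components: keep as-is
    return {key: value}


class SketchLossRecorder:
    """Hypothetical class mirroring the documented argument handling."""

    def __init__(self, loss, reconstruction_loss=None, kl_local=None, kl_global=None, **kwargs):
        self.loss_dict = _as_dict(loss, "loss")
        self.reconstruction_loss_dict = _as_dict(reconstruction_loss, "reconstruction_loss")
        self.kl_local_dict = _as_dict(kl_local, "kl_local")
        self.kl_global_dict = _as_dict(kl_global, "kl_global")
        # Extra keyword arguments (additional metrics) become attributes.
        for name, value in kwargs.items():
            setattr(self, name, value)


recorder = SketchLossRecorder(
    2.5,
    reconstruction_loss=[1.0, 3.0],
    kl_local={"kl_z": [0.5, 0.5]},
    n_obs=2,  # hypothetical extra metric passed through **kwargs
)
print(recorder.reconstruction_loss_dict)  # {'reconstruction_loss': [1.0, 3.0]}
print(recorder.n_obs)  # 2
```

Normalizing every component to a dictionary lets a model report several named terms (for example, one KL term per latent variable) through a single argument.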

Attributes table

  Attribute                  Return type
  kl_global                  Union[Tensor, Array]
  kl_global_sum              Union[Tensor, Array]
  kl_local                   Union[Tensor, Array]
  kl_local_sum               Union[Tensor, Array]
  loss                       Union[Tensor, Array]
  reconstruction_loss        Union[Tensor, Array]
  reconstruction_loss_sum    Union[Tensor, Array]
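This page does not define the *_sum attributes. One plausible reading, sketched below purely as an assumption, is that a per-observation component (such as kl_local) is first combined across its named dictionary entries and then totaled over the minibatch. The helper names `combine_entries` and `minibatch_total` are hypothetical, and lists of floats stand in for per-observation tensors.

```python
from typing import Dict, List

# Hypothetical sketch: lists of floats stand in for per-observation tensors.
# Assumption: a *_sum attribute totals the per-observation values over the
# minibatch, after adding together all named entries of that component.


def combine_entries(entries: Dict[str, List[float]]) -> List[float]:
    """Add the named entries elementwise, giving one value per observation."""
    return [sum(values) for values in zip(*entries.values())]


def minibatch_total(entries: Dict[str, List[float]]) -> float:
    """Collapse the per-observation values into a single number for logging."""
    return sum(combine_entries(entries))


kl_local = {"kl_z": [0.5, 0.5], "kl_l": [0.25, 0.75]}
print(combine_entries(kl_local))  # [0.75, 1.25]
print(minibatch_total(kl_local))  # 2.0
```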

Methods table

  dict_sum(x)    Wrapper around LossOutput.dict_sum.

Attributes

LossRecorder.kl_global
  Return type: Union[Tensor, Array]

LossRecorder.kl_global_sum
  Return type: Union[Tensor, Array]

LossRecorder.kl_local
  Return type: Union[Tensor, Array]

LossRecorder.kl_local_sum
  Return type: Union[Tensor, Array]

LossRecorder.loss
  Return type: Union[Tensor, Array]

LossRecorder.reconstruction_loss
  Return type: Union[Tensor, Array]

LossRecorder.reconstruction_loss_sum
  Return type: Union[Tensor, Array]

Methods

LossRecorder.dict_sum(x)
  Wrapper around LossOutput.dict_sum.
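LossOutput.dict_sum itself is not documented on this page. A minimal sketch of what such a helper plausibly does, assuming it sums the values of a dictionary and returns any non-dictionary input unchanged, is shown below; plain numbers stand in for tensors, and this is an illustration rather than the scvi-tools source.

```python
from typing import Dict, Union

Number = Union[int, float]  # stand-in for Tensor/Array


def dict_sum(x: Union[Dict[str, Number], Number]) -> Number:
    """Sum the values of a dict; return non-dict inputs unchanged."""
    if isinstance(x, dict):
        return sum(x.values())
    return x


print(dict_sum({"kl_z": 1.0, "kl_l": 2.5}))  # 3.5
print(dict_sum(4.0))  # 4.0
```

Passing bare values through unchanged means callers can apply the same helper whether a component was recorded as a single tensor or as a dictionary of named terms.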