scvi.train.PyroTrainingPlan.backward

PyroTrainingPlan.backward(*args, **kwargs)

Override backward with your own implementation if you need to.

Parameters

loss
    Loss to backpropagate, already scaled for accumulated gradients if requested.

optimizer
    Current optimizer being used.

optimizer_idx
    Index of the current optimizer being used.

Called to perform the backward step. Feel free to override as needed. The loss passed in has already been scaled for accumulated gradients, if requested.

Example:

def backward(self, loss, optimizer, optimizer_idx):
    loss.backward()
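
A custom training plan can build on this hook, for example to clip gradients after the backward pass. The sketch below is illustrative only, not scvi-tools' own implementation; it assumes the (loss, optimizer, optimizer_idx) signature documented above, and the ClippedBackwardTrainingPlan name is hypothetical.

import torch

from scvi.train import PyroTrainingPlan


class ClippedBackwardTrainingPlan(PyroTrainingPlan):
    # Hypothetical plan that clips gradient norms after backpropagation.

    def backward(self, loss, optimizer, optimizer_idx):
        # The loss has already been scaled for gradient accumulation.
        loss.backward()
        # Illustrative extra step: clip the gradient norm of all trainable parameters.
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)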