scvi.train.PyroTrainingPlan.optimizer_step

PyroTrainingPlan.optimizer_step(*args, **kwargs)[source]

Override this method to adjust the default way the Trainer calls each optimizer. By default, Lightning calls step() and zero_grad(), as shown in the example, once per optimizer.

Warning

If you are overriding this method, make sure that you pass the optimizer_closure parameter to the optimizer.step() function as shown in the examples. This ensures that training_step(), optimizer.zero_grad(), and backward() are called within run_training_batch().

Parameters

epoch: Current epoch.
batch_idx: Index of the current batch.
optimizer: A PyTorch optimizer.
optimizer_idx: If you used multiple optimizers, this indexes into that list.
optimizer_closure: Closure for all optimizers.
on_tpu: True if TPU backward is required.
using_native_amp: True if using native AMP (automatic mixed precision).
using_lbfgs: True if the matching optimizer is torch.optim.LBFGS.

Examples:

# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    optimizer.step(closure=optimizer_closure)

# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    # update generator opt every step
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    # update discriminator opt every 2 steps
    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0:
            optimizer.step(closure=optimizer_closure)

    # ...
    # add as many optimizers as you want
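
Depending on your Lightning version, skipping optimizer.step() for the discriminator above may also skip its closure, so training_step() and backward() would not run for that optimizer on that batch. The following is a minimal, hedged sketch of one way to handle this, calling the closure directly when the step is skipped (it assumes the closure is a plain callable in your version):

# Alternating schedule, but the closure still runs when the step is skipped
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0:
            optimizer.step(closure=optimizer_closure)
        else:
            # run training_step()/backward() without stepping the optimizer
            optimizer_closure()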

Here’s another example showing how to use this for more advanced use cases, such as learning rate warm-up:

# learning rate warm-up
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    # warm up lr
    if self.trainer.global_step < 500:
        lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * self.learning_rate

    # update params
    optimizer.step(closure=optimizer_closure)
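
For scvi-tools specifically, the same idea can be packaged in a subclass. The sketch below is illustrative only: the class name, warmup_steps, and base_lr are hypothetical, and it assumes the plan is driven by a standard PyTorch optimizer whose step accepts a closure.

# Hypothetical subclass sketch; names and defaults are assumptions
from scvi.train import PyroTrainingPlan

class WarmupPyroTrainingPlan(PyroTrainingPlan):
    """Sketch of a plan that linearly warms up the learning rate."""

    def __init__(self, *args, warmup_steps=500, base_lr=1e-3, **kwargs):
        super().__init__(*args, **kwargs)
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # linearly scale the learning rate over the first `warmup_steps` global steps
        if self.trainer.global_step < self.warmup_steps:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.warmup_steps)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.base_lr
        # always pass the closure so training_step()/backward()/zero_grad() run
        optimizer.step(closure=optimizer_closure)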