scvi.train.JaxTrainingPlan#

class scvi.train.JaxTrainingPlan(module, *, optimizer='Adam', optimizer_creator=None, lr=0.001, weight_decay=1e-06, eps=0.01, max_norm=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, **loss_kwargs)[source]#

Bases: TrainingPlan

Lightning module task to train scvi-tools Jax modules.

Parameters:
  • module (JaxBaseModuleClass) – An instance of JaxModuleWrapper.

  • optimizer (Literal['Adam', 'AdamW', 'Custom']) – One of “Adam”, “AdamW”, or “Custom”, which requires a custom optimizer creator callable to be passed via optimizer_creator.

  • optimizer_creator (Optional[Callable[[], GradientTransformation]]) – A callable returning a GradientTransformation. This allows using any optax optimizer with custom hyperparameters.

  • lr (float) – Learning rate used for optimization, when optimizer_creator is None.

  • weight_decay (float) – Weight decay used in optimization, when optimizer_creator is None.

  • eps (float) – Numerical stability epsilon used by the optimizer, when optimizer_creator is None.

  • max_norm (Optional[float]) – Max global norm of gradients for gradient clipping.

  • n_steps_kl_warmup (Optional[int]) – Number of training steps (minibatches) to scale weight on KL divergences from min_kl_weight to max_kl_weight. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup (Optional[int]) – Number of epochs to scale weight on KL divergences from min_kl_weight to max_kl_weight. Overrides n_steps_kl_warmup when both are not None.
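
Example (a minimal, hedged sketch; model_module stands in for an already-initialized scvi-tools Jax module and is an illustrative name, not part of this API):

from scvi.train import JaxTrainingPlan

# Built-in optimizer path: lr / weight_decay / eps only apply when
# optimizer_creator is None.
plan = JaxTrainingPlan(
    model_module,          # assumed: a JaxBaseModuleClass instance
    optimizer="AdamW",
    lr=1e-3,
    weight_decay=1e-6,
    eps=0.01,
    max_norm=1.0,          # clip gradients by global norm
    n_epochs_kl_warmup=400,
)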

Attributes table#

training

Methods table#

backward(*args, **kwargs)

Called to perform backward on the loss returned in training_step().

configure_optimizers()

Shim optimizer for PyTorch Lightning.

forward(*args, **kwargs)

Passthrough to the module's forward method.

get_optimizer_creator()

Get optimizer creator for the model.

jit_training_step(state, batch, rngs, **kwargs)

Jit training step.

jit_validation_step(state, batch, rngs, **kwargs)

Jit validation step.

optimizer_step(*args, **kwargs)

Override this method to adjust the default way the Trainer calls each optimizer.

set_train_state(params[, state])

Set the state of the module.

training_step(batch, batch_idx)

Training step for Jax.

transfer_batch_to_device(batch, device, ...)

Bypass PyTorch Lightning device management.

validation_step(batch, batch_idx)

Validation step for Jax.

Attributes#

training

JaxTrainingPlan.training: bool#

Methods#

backward

JaxTrainingPlan.backward(*args, **kwargs)[source]#

Called to perform backward on the loss returned in training_step(). Override this hook with your own implementation if you need to.

Parameters:
  • loss – The loss tensor returned by training_step(). If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps).

  • optimizer – Current optimizer being used. None if using manual optimization.

  • optimizer_idx – Index of the current optimizer being used. None if using manual optimization.

Example:

def backward(self, loss, optimizer, optimizer_idx):
    loss.backward()

configure_optimizers

JaxTrainingPlan.configure_optimizers()[source]#

Shim optimizer for PyTorch Lightning.

PyTorch Lightning wants to take steps on an optimizer returned by this function in order to increment the global step count. See the PyTorch Lightning documentation on the manual optimization loop.

Here we provide a shim optimizer that we can take steps on at minimal computational cost in order to keep Lightning happy :).
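
The idea can be sketched as follows (a hypothetical illustration, not the library's actual code): a single throwaway torch parameter gives Lightning a genuine optimizer to step, while the real parameter updates happen in JAX via optax.

import torch

def configure_optimizers(self):
    # Sketch only: one dummy parameter so that Lightning has a real torch
    # optimizer to call step()/zero_grad() on; the actual parameter updates
    # are performed in JAX via optax, not by this optimizer.
    self._dummy_param = torch.nn.Parameter(torch.zeros(1))
    return torch.optim.Adam([self._dummy_param], lr=1e-3)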

forward

JaxTrainingPlan.forward(*args, **kwargs)[source]#

Passthrough to the module’s forward method.

get_optimizer_creator

JaxTrainingPlan.get_optimizer_creator()[source]#

Get optimizer creator for the model.

Return type:

Callable[[], GradientTransformation]
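
For the "Custom" optimizer path, the creator is a zero-argument callable returning an optax.GradientTransformation. A hedged sketch (make_optimizer and module are illustrative names, not part of the API):

import optax

def make_optimizer() -> optax.GradientTransformation:
    # Illustrative custom optimizer: AdamW chained with global-norm clipping.
    return optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=1e-3, weight_decay=1e-6),
    )

plan = JaxTrainingPlan(module, optimizer="Custom", optimizer_creator=make_optimizer)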

jit_training_step

static JaxTrainingPlan.jit_training_step(state, batch, rngs, **kwargs)[source]#

Jit training step.

jit_validation_step

JaxTrainingPlan.jit_validation_step(state, batch, rngs, **kwargs)[source]#

Jit validation step.

optimizer_step

JaxTrainingPlan.optimizer_step(*args, **kwargs)[source]#

Override this method to adjust the default way the Trainer calls each optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.

Parameters:
  • epoch – Current epoch

  • batch_idx – Index of current batch

  • optimizer – A PyTorch optimizer

  • optimizer_idx – If you used multiple optimizers, this indexes into that list.

  • optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().

  • on_tpu – True if TPU backward is required

  • using_lbfgs – True if the matching optimizer is torch.optim.LBFGS

Examples:

# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_lbfgs):
    optimizer.step(closure=optimizer_closure)

# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_lbfgs):
    # update generator opt every step
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    # update discriminator opt every 2 steps
    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0 :
            optimizer.step(closure=optimizer_closure)
        else:
            # call the closure by itself to run `training_step` + `backward` without an optimizer step
            optimizer_closure()

    # ...
    # add as many optimizers as you want

Here’s another example showing how to use this for more advanced things such as learning rate warm-up:

# learning rate warm-up
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu,
    using_lbfgs,
):
    # update params
    optimizer.step(closure=optimizer_closure)

    # manually warm up lr without a scheduler
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_scale * self.learning_rate

set_train_state

JaxTrainingPlan.set_train_state(params, state=None)[source]#

Set the state of the module.

training_step

JaxTrainingPlan.training_step(batch, batch_idx)[source]#

Training step for Jax.

transfer_batch_to_device

static JaxTrainingPlan.transfer_batch_to_device(batch, device, dataloader_idx)[source]#

Bypass PyTorch Lightning device management.
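
Conceptually, the bypass amounts to handing the batch back untouched so that JAX, rather than Lightning, controls device placement. A hedged sketch of that idea (not necessarily the exact implementation):

@staticmethod
def transfer_batch_to_device(batch, device, dataloader_idx):
    # Sketch only: return the batch as-is; JAX arrays manage their own
    # device placement, so no .to(device) calls are made here.
    return batch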

validation_step

JaxTrainingPlan.validation_step(batch, batch_idx)[source]#

Validation step for Jax.