scvi.train.JaxTrainingPlan#

class scvi.train.JaxTrainingPlan(module, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, optim_kwargs=None, **loss_kwargs)[source]#

Bases: LightningModule

Lightning module task to train JAX scvi-tools modules.

Parameters:
  • module (JaxModuleWrapper) – An instance of JaxModuleWrapper.

  • n_steps_kl_warmup (Optional[int] (default: None)) – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup (Optional[int] (default: 400)) – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None; the precedence is illustrated in the sketch below.

  • optim_kwargs (Optional[dict] (default: None)) – Keyword arguments passed to the optimizer.

  • **loss_kwargs – Keyword arguments passed to the module’s loss method.
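The precedence between the two warm-up arguments can be summarized with a small illustrative helper; this is a sketch for clarity, not the library’s implementation, and the function name is hypothetical:

def kl_warmup_target(n_steps_kl_warmup=None, n_epochs_kl_warmup=400):
    # Epoch-based warm-up takes precedence when both arguments are set.
    if n_epochs_kl_warmup is not None:
        return "epoch", n_epochs_kl_warmup
    # Step-based warm-up only applies when the epoch setting is None.
    if n_steps_kl_warmup is not None:
        return "step", n_steps_kl_warmup
    # Neither is set: no warm-up is performed.
    return None, None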

Attributes table#

kl_weight

Scaling factor on KL divergence during training.

Methods table#

backward(*args, **kwargs)

Called to perform backward on the loss returned in training_step().

configure_optimizers()

Configure optimizers.

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

jit_training_step(state, batch, rngs, **kwargs)

JIT-compiled training step.

jit_validation_step(state, batch, rngs, **kwargs)

JIT-compiled validation step.

optimizer_step(*args, **kwargs)

Override this method to adjust the default way the Trainer calls each optimizer.

set_train_state(params[, state])

Set the state of the module.

training_step(batch, batch_idx)

Training step for JAX.

transfer_batch_to_device(batch, device, ...)

Bypass PyTorch Lightning device management.

validation_step(batch, batch_idx)

Validation step for JAX.

Attributes#

kl_weight

JaxTrainingPlan.kl_weight[source]#

Scaling factor on KL divergence during training.
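As a rough sketch of the linear warm-up described by the constructor arguments (a hypothetical helper, not the property’s actual code), the weight can be thought of as:

def kl_weight(current_epoch, current_step, n_epochs_kl_warmup=400, n_steps_kl_warmup=None):
    # Illustration only: ramp the KL weight linearly from 0 to 1.
    if n_epochs_kl_warmup is not None:
        return min(1.0, current_epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        return min(1.0, current_step / n_steps_kl_warmup)
    # No warm-up configured: the KL term is fully weighted.
    return 1.0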

training

JaxTrainingPlan.training: bool#

Methods#

backward

JaxTrainingPlan.backward(*args, **kwargs)[source]#

Called to perform backward on the loss returned in training_step(). Override this hook with your own implementation if you need to.

Parameters:
  • loss – The loss tensor returned by training_step(). If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps).

  • optimizer – Current optimizer being used. None if using manual optimization.

  • optimizer_idx – Index of the current optimizer being used. None if using manual optimization.

Example:

def backward(self, loss, optimizer, optimizer_idx):
    loss.backward()

configure_optimizers

JaxTrainingPlan.configure_optimizers()[source]#

Configure optimizers.

forward

JaxTrainingPlan.forward(*args, **kwargs)[source]#

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

jit_training_step

static JaxTrainingPlan.jit_training_step(state, batch, rngs, **kwargs)[source]#

JIT-compiled training step.
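The exact body is internal to scvi-tools, but a minimal sketch of a jitted training step in the same spirit, built on flax’s TrainState and optax (the loss function and batch field names below are placeholders, not the module’s real ones), could look like this:

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


def _loss_fn(params, batch, rngs):
    # Placeholder loss: mean squared error on hypothetical "x"/"y" fields.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2), {}


@jax.jit
def jit_training_step(state, batch, rngs):
    # Differentiate the loss w.r.t. the parameters held by the train state,
    # then apply the optax update inside the same compiled function.
    (loss, _), grads = jax.value_and_grad(_loss_fn, has_aux=True)(
        state.params, batch, rngs
    )
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


# Toy usage with a linear model and an Adam optimizer.
params = {"w": jnp.zeros((3, 1))}
state = train_state.TrainState.create(
    apply_fn=lambda p, x: x @ p["w"], params=params, tx=optax.adam(1e-3)
)
batch = {"x": jnp.ones((4, 3)), "y": jnp.ones((4, 1))}
state, loss = jit_training_step(state, batch, rngs=None)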

jit_validation_step

JaxTrainingPlan.jit_validation_step(state, batch, rngs, **kwargs)[source]#

JIT-compiled validation step.

optimizer_step

JaxTrainingPlan.optimizer_step(*args, **kwargs)[source]#

Override this method to adjust the default way the Trainer calls each optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.

Parameters:
  • epoch – Current epoch

  • batch_idx – Index of current batch

  • optimizer – A PyTorch optimizer

  • optimizer_idx – If you used multiple optimizers, this indexes into that list.

  • optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().

  • on_tpu – True if TPU backward is required

  • using_native_amp – True if using native AMP

  • using_lbfgs – True if the matching optimizer is torch.optim.LBFGS

Examples:

# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    optimizer.step(closure=optimizer_closure)

# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    # update generator opt every step
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    # update discriminator opt every 2 steps
    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0:
            optimizer.step(closure=optimizer_closure)
        else:
            # call the closure by itself to run `training_step` + `backward` without an optimizer step
            optimizer_closure()

    # ...
    # add as many optimizers as you want

Here’s another example showing how to use this for more advanced things such as learning rate warm-up:

# learning rate warm-up
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu,
    using_native_amp,
    using_lbfgs,
):
    # update params
    optimizer.step(closure=optimizer_closure)

    # manually warm up lr without a scheduler
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_scale * self.learning_rate

set_train_state

JaxTrainingPlan.set_train_state(params, state=None)[source]#

Set the train state of the module.

training_step

JaxTrainingPlan.training_step(batch, batch_idx)[source]#

Training step for JAX.
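Following the override style of the examples above, a hedged sketch of how such a step can be wired up (the self.train_state and self.rngs attribute names are placeholders, not necessarily the plan’s real attributes):

def training_step(self, batch, batch_idx):
    # Delegate the numerical work to the jitted step; keep only Python-side
    # bookkeeping here, i.e. storing the new state and logging the loss.
    self.train_state, loss = self.jit_training_step(
        self.train_state, batch, self.rngs
    )
    self.log("train_loss", float(loss), on_epoch=True, prog_bar=True)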

transfer_batch_to_device

static JaxTrainingPlan.transfer_batch_to_device(batch, device, dataloader_idx)[source]#

Bypass PyTorch Lightning device management.
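In spirit (an assumption based on the description above, not the library’s verbatim code), bypassing Lightning’s device management amounts to returning the batch unchanged so that JAX decides array placement itself:

@staticmethod
def transfer_batch_to_device(batch, device, dataloader_idx):
    # Return the batch as-is: JAX, not Lightning, handles device placement.
    return batch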

validation_step

JaxTrainingPlan.validation_step(batch, batch_idx)[source]#

Validation step for JAX.