scvi.train.JaxTrainingPlan#
- class scvi.train.JaxTrainingPlan(module, *, optimizer='Adam', optimizer_creator=None, lr=0.001, weight_decay=1e-06, eps=0.01, max_norm=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, **loss_kwargs)[source]#
Bases: TrainingPlan
Lightning module task to train Jax scvi-tools modules.
- Parameters:
  - module (JaxBaseModuleClass) – An instance of JaxModuleWrapper.
  - optimizer (Literal['Adam', 'AdamW', 'Custom']) – One of "Adam", "AdamW", or "Custom", which requires a custom optimizer creator callable to be passed via optimizer_creator.
  - optimizer_creator (Optional[Callable[[], GradientTransformation]]) – A callable returning a GradientTransformation. This allows using any optax optimizer with custom hyperparameters; see the sketch after this list.
  - lr (float) – Learning rate used for optimization, when optimizer_creator is None.
  - weight_decay (float) – Weight decay used in optimization, when optimizer_creator is None.
  - eps (float) – eps used for optimization, when optimizer_creator is None.
  - max_norm (Optional[float]) – Max global norm of gradients for gradient clipping.
  - n_steps_kl_warmup (Optional[int]) – Number of training steps (minibatches) to scale the weight on KL divergences from min_kl_weight to max_kl_weight. Only activated when n_epochs_kl_warmup is set to None.
  - n_epochs_kl_warmup (Optional[int]) – Number of epochs to scale the weight on KL divergences from min_kl_weight to max_kl_weight. Overrides n_steps_kl_warmup when both are not None.
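A minimal sketch of passing a custom optax optimizer via optimizer_creator. The module variable, the creator name, and the hyperparameter values are assumptions for illustration:

import optax

from scvi.train import JaxTrainingPlan

def my_optimizer_creator():
    # Hypothetical hyperparameters; any optax GradientTransformation works here.
    return optax.adamw(learning_rate=1e-3, weight_decay=1e-6)

# `module` is assumed to be an existing JaxBaseModuleClass instance.
plan = JaxTrainingPlan(
    module,
    optimizer="Custom",
    optimizer_creator=my_optimizer_creator,
    max_norm=1.0,  # clip gradients to a global norm of 1.0
)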
Attributes table#
- training

Methods table#
- backward – Called to perform backward on the loss returned in training_step().
- configure_optimizers – Shim optimizer for PyTorch Lightning.
- forward – Passthrough to the module's forward method.
- get_optimizer_creator – Get optimizer creator for the model.
- jit_training_step – Jit training step.
- jit_validation_step – Jit validation step.
- optimizer_step – Override this method to adjust the default way the Trainer calls each optimizer.
- set_train_state – Set the state of the module.
- training_step – Training step for Jax.
- transfer_batch_to_device – Bypass PyTorch Lightning device management.
- validation_step – Validation step for Jax.
Attributes#
training
Methods#
backward
- JaxTrainingPlan.backward(*args, **kwargs)[source]#
Called to perform backward on the loss returned in training_step(). Override this hook with your own implementation if you need to.
- Parameters:
  - loss – The loss tensor returned by training_step(). If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps).
  - optimizer – Current optimizer being used. None if using manual optimization.
  - optimizer_idx – Index of the current optimizer being used. None if using manual optimization.
Example:

def backward(self, loss, optimizer, optimizer_idx):
    loss.backward()
configure_optimizers
- JaxTrainingPlan.configure_optimizers()[source]#
Shim optimizer for PyTorch Lightning.
PyTorch Lightning wants to take steps on an optimizer returned by this function in order to increment the global step count. See the PyTorch Lightning manual optimization loop.
Here we provide a shim optimizer that we can take steps on at minimal computational cost in order to keep Lightning happy :).
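As an illustration only (not necessarily the actual scvi-tools implementation), such a shim can be a throwaway PyTorch optimizer over a single dummy parameter:

import torch

# Stepping this optimizer updates nothing meaningful, but it gives
# Lightning something to step so its global step counter advances.
dummy_param = torch.nn.Parameter(torch.zeros(1))
shim_optimizer = torch.optim.SGD([dummy_param], lr=0.0)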
forward
Passthrough to the module's forward method.
get_optimizer_creator
- JaxTrainingPlan.get_optimizer_creator()[source]#
Get optimizer creator for the model.
- Return type:
Callable[[], GradientTransformation]
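A short sketch of consuming the returned creator directly with optax; plan and params are assumed to exist (a JaxTrainingPlan instance and a pytree of module parameters):

tx = plan.get_optimizer_creator()()  # an optax GradientTransformation
opt_state = tx.init(params)  # standard optax state initialization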
jit_training_step
Jit training step.
jit_validation_step
Jit validation step.
optimizer_step
- JaxTrainingPlan.optimizer_step(*args, **kwargs)[source]#
Override this method to adjust the default way the Trainer calls each optimizer.
By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer. This method (and zero_grad()) won't be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.
- Parameters:
  - epoch – Current epoch.
  - batch_idx – Index of current batch.
  - optimizer – A PyTorch optimizer.
  - optimizer_idx – If you used multiple optimizers, this indexes into that list.
  - optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().
  - on_tpu – True if TPU backward is required.
  - using_lbfgs – True if the matching optimizer is torch.optim.LBFGS.
Examples:
# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_lbfgs):
    optimizer.step(closure=optimizer_closure)

# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_lbfgs):
    # update generator opt every step
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    # update discriminator opt every 2 steps
    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0:
            optimizer.step(closure=optimizer_closure)
        else:
            # call the closure by itself to run `training_step` + `backward`
            # without an optimizer step
            optimizer_closure()

    # ...
    # add as many optimizers as you want
Here’s another example showing how to use this for more advanced things such as learning rate warm-up:
# learning rate warm-up
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu,
    using_lbfgs,
):
    # update params
    optimizer.step(closure=optimizer_closure)

    # manually warm up lr without a scheduler
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_scale * self.learning_rate
set_train_state
Set the state of the module.
training_step
Training step for Jax.
transfer_batch_to_device
- static JaxTrainingPlan.transfer_batch_to_device(batch, device, dataloader_idx)[source]#
Bypass PyTorch Lightning device management.
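Since JAX manages device placement itself, the idea is to hand the batch back untouched. The body below is an assumed sketch (signature from the docs; implementation not confirmed):

@staticmethod
def transfer_batch_to_device(batch, device, dataloader_idx):
    # Return the batch unchanged so JAX, not Lightning, controls placement.
    return batch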
validation_step
Validation step for Jax.