scvi.train.JaxTrainingPlan#
- class scvi.train.JaxTrainingPlan(module, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, optim_kwargs=None, **loss_kwargs)[source]#
Bases: LightningModule

Lightning module task to train scvi-tools Jax modules.
- Parameters:
  - module (JaxModuleWrapper) – An instance of JaxModuleWrapper.
  - n_steps_kl_warmup (Optional[int] (default: None)) – Number of training steps (minibatches) over which to scale the weight on the KL divergence from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.
  - n_epochs_kl_warmup (Optional[int] (default: 400)) – Number of epochs over which to scale the weight on the KL divergence from 0 to 1. Overrides n_steps_kl_warmup when both are not None (see the sketch after this list).
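The interaction between the two warm-up parameters can be pictured with a small sketch. The helper below only illustrates the linear 0-to-1 scaling and the precedence described above; the function name and exact arithmetic are assumptions for illustration, not the library's code.

    def kl_warmup_weight(current_epoch, global_step,
                         n_epochs_kl_warmup=400, n_steps_kl_warmup=None):
        # Linearly scale the KL weight from 0 to 1 during warm-up (illustrative sketch).
        if n_epochs_kl_warmup is not None:
            # Epoch-based warm-up takes precedence when both parameters are set.
            return min(1.0, current_epoch / n_epochs_kl_warmup)
        if n_steps_kl_warmup is not None:
            # Otherwise fall back to step-based (minibatch) warm-up.
            return min(1.0, global_step / n_steps_kl_warmup)
        return 1.0  # no warm-up configured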
Attributes table#
| Attribute | Description |
| --- | --- |
| kl_weight | Scaling factor on KL divergence during training. |
| training | |
Methods table#
| Method | Description |
| --- | --- |
| backward | Called to perform backward on the loss returned in training_step(). |
| configure_optimizers | Configure optimizers. |
| forward | Same as torch.nn.Module.forward(). |
| jit_training_step | Jit training step. |
| jit_validation_step | Jit validation step. |
| optimizer_step | Override this method to adjust the default way the Trainer calls each optimizer. |
| set_train_state | Set the state of the module. |
| training_step | Training step for Jax. |
| transfer_batch_to_device | Bypass PyTorch Lightning device management. |
| validation_step | Validation step for Jax. |
Attributes#
kl_weight
training
Methods#
backward
- JaxTrainingPlan.backward(*args, **kwargs)[source]#
Called to perform backward on the loss returned in training_step(). Override this hook with your own implementation if you need to.
- Parameters:
  - loss – The loss tensor returned by training_step(). If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps).
  - optimizer – Current optimizer being used. None if using manual optimization.
  - optimizer_idx – Index of the current optimizer being used. None if using manual optimization.
Example:

    def backward(self, loss, optimizer, optimizer_idx):
        loss.backward()
configure_optimizers
forward
- JaxTrainingPlan.forward(*args, **kwargs)[source]#
Same as torch.nn.Module.forward().
- Parameters:
  - *args – Whatever you decide to pass into the forward method.
  - **kwargs – Keyword arguments are also possible.
- Returns:
  Your model's output
jit_training_step
jit_validation_step
optimizer_step
- JaxTrainingPlan.optimizer_step(*args, **kwargs)[source]#
Override this method to adjust the default way the Trainer calls each optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer. This method (and zero_grad()) won't be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.

- Parameters:
  - epoch – Current epoch
  - batch_idx – Index of current batch
  - optimizer – A PyTorch optimizer
  - optimizer_idx – If you used multiple optimizers, this indexes into that list.
  - optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().
  - on_tpu – True if TPU backward is required
  - using_native_amp – True if using native amp
  - using_lbfgs – True if the matching optimizer is torch.optim.LBFGS

Examples:
    # DEFAULT
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        optimizer.step(closure=optimizer_closure)

    # Alternating schedule for optimizer steps (i.e.: GANs)
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # update generator opt every step
        if optimizer_idx == 0:
            optimizer.step(closure=optimizer_closure)

        # update discriminator opt every 2 steps
        if optimizer_idx == 1:
            if (batch_idx + 1) % 2 == 0:
                optimizer.step(closure=optimizer_closure)
            else:
                # call the closure by itself to run `training_step` + `backward`
                # without an optimizer step
                optimizer_closure()

        # ...
        # add as many optimizers as you want
Here’s another example showing how to use this for more advanced things such as learning rate warm-up:
    # learning rate warm-up
    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu,
        using_native_amp,
        using_lbfgs,
    ):
        # update params
        optimizer.step(closure=optimizer_closure)

        # manually warm up lr without a scheduler
        if self.trainer.global_step < 500:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.learning_rate
set_train_state
training_step
transfer_batch_to_device
- static JaxTrainingPlan.transfer_batch_to_device(batch, device, dataloader_idx)[source]#
Bypass PyTorch Lightning device management.
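A minimal sketch of what "bypassing" device management could look like, assuming the hook simply returns the batch untouched so that JAX, rather than Lightning, decides device placement; this is an illustrative assumption, not the library's verbatim implementation.

    @staticmethod
    def transfer_batch_to_device(batch, device, dataloader_idx):
        # Assumed behavior: return the batch unchanged so Lightning does not
        # move it to a torch device; JAX handles device placement itself.
        return batch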
validation_step