scvi.model.base.PyroJitGuideWarmup
- class scvi.model.base.PyroJitGuideWarmup(dataloader=None)
A callback to warm up a Pyro guide.
It initializes all the relevant parameters by running one minibatch through the Pyro model.
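The reason for a warmup can be shown with a toy, pure-Python sketch that involves no Pyro. It assumes a guide whose parameters are created lazily on its first call. Pyro guides behave similarly: they register parameters in the parameter store the first time they run, so any code that needs those parameters up front (for example JIT tracing, as the class name hints) would find none. `LazyGuide` and `warmup` are hypothetical names used only for illustration. This is not the scvi-tools implementation.

```python
from typing import Dict, Iterable, List


class LazyGuide:
    """Toy guide (hypothetical): parameters are created on the first call."""

    def __init__(self) -> None:
        self.params: Dict[str, float] = {}

    def __call__(self, batch: List[float]) -> float:
        # Create parameters only when the guide first sees data, mimicking
        # lazy initialization whose initial values depend on the batch.
        if not self.params:
            self.params["loc"] = sum(batch) / len(batch)
            self.params["scale"] = 1.0
        return self.params["loc"]


def warmup(guide: LazyGuide, dataloader: Iterable[List[float]]) -> None:
    """Run exactly one minibatch through the guide so its parameters exist."""
    for batch in dataloader:
        guide(batch)
        break  # one minibatch is enough


guide = LazyGuide()
print(len(guide.params))  # 0: nothing to optimize or trace yet
warmup(guide, [[1.0, 2.0, 3.0], [4.0, 5.0]])
print(sorted(guide.params))  # ['loc', 'scale']
```

Only the first minibatch is consumed, so the cost of the warmup is a single forward pass.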
Attributes table
| state_key | Identifier for the state of the callback. |
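Methods table
| load_state_dict | Called when loading a checkpoint; implement to reload the callback's state given its state_dict. |
| on_after_backward | Called after loss.backward() and before optimizers are stepped. |
| on_before_backward | Called before loss.backward(). |
| on_before_optimizer_step | Called before optimizer.step(). |
| on_before_zero_grad | Called before optimizer.zero_grad(). |
| on_exception | Called when any trainer execution is interrupted by an exception. |
| on_fit_end | Called when fit ends. |
| on_fit_start | Called when fit begins. |
| on_load_checkpoint | Called when loading a model checkpoint; use it to reload state. |
| on_predict_batch_end | Called when the predict batch ends. |
| on_predict_batch_start | Called when the predict batch begins. |
| on_predict_end | Called when prediction ends. |
| on_predict_epoch_end | Called when the predict epoch ends. |
| on_predict_epoch_start | Called when the predict epoch begins. |
| on_predict_start | Called when prediction begins. |
| on_sanity_check_end | Called when the validation sanity check ends. |
| on_sanity_check_start | Called when the validation sanity check starts. |
| on_save_checkpoint | Called when saving a checkpoint, to give you a chance to store anything else you might want to save. |
| on_test_batch_end | Called when the test batch ends. |
| on_test_batch_start | Called when the test batch begins. |
| on_test_end | Called when testing ends. |
| on_test_epoch_end | Called when the test epoch ends. |
| on_test_epoch_start | Called when the test epoch begins. |
| on_test_start | Called when testing begins. |
| on_train_batch_end | Called when the train batch ends. |
| on_train_batch_start | Called when the train batch begins. |
| on_train_end | Called when training ends. |
| on_train_epoch_end | Called when the train epoch ends. |
| on_train_epoch_start | Called when the train epoch begins. |
| on_train_start | Automatically warms up the Pyro guide when training starts. |
| on_validation_batch_end | Called when the validation batch ends. |
| on_validation_batch_start | Called when the validation batch begins. |
| on_validation_end | Called when the validation loop ends. |
| on_validation_epoch_end | Called when the validation epoch ends. |
| on_validation_epoch_start | Called when the validation epoch begins. |
| on_validation_start | Called when the validation loop begins. |
| setup | Called when fit, validate, test, predict, or tune begins. |
| state_dict | Called when saving a checkpoint; implement to generate the callback's state_dict. |
| teardown | Called when fit, validate, test, predict, or tune ends. |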
Attributes
state_key
- PyroJitGuideWarmup.state_key
Identifier for the state of the callback.
Used to store and retrieve a callback’s state from the checkpoint dictionary via checkpoint["callbacks"][state_key]. A callback must provide a unique state key if (1) it has state and (2) the state of multiple instances of that callback needs to be kept.
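Why the key must be unique per instance can be sketched in pure Python, with a plain dictionary standing in for the checkpoint. `Tracker`, `naive_key`, and `save` are hypothetical names; the sketch assumes a key built from the class name alone versus one that also includes the argument distinguishing the instances. It is an illustration, not Lightning's own key-generation code.

```python
from typing import Any, Dict, Iterable


class Tracker:
    """Hypothetical stateful callback, used only to illustrate state keys."""

    def __init__(self, monitor: str) -> None:
        self.monitor = monitor
        self.best = float("inf")

    @property
    def naive_key(self) -> str:
        # Class name only: every instance shares the same key.
        return type(self).__qualname__

    @property
    def state_key(self) -> str:
        # Class name plus the distinguishing argument: one key per instance.
        return f"{type(self).__qualname__}[monitor={self.monitor}]"


def save(callbacks: Iterable[Tracker], key_attr: str) -> Dict[str, Any]:
    """Store each callback's state under checkpoint["callbacks"][key]."""
    checkpoint: Dict[str, Any] = {"callbacks": {}}
    for cb in callbacks:
        checkpoint["callbacks"][getattr(cb, key_attr)] = {"best": cb.best}
    return checkpoint


loss_cb, elbo_cb = Tracker("val_loss"), Tracker("elbo")
loss_cb.best, elbo_cb.best = 0.5, 120.0

naive = save([loss_cb, elbo_cb], "naive_key")
keyed = save([loss_cb, elbo_cb], "state_key")
print(len(naive["callbacks"]))  # 1: the second entry overwrote the first
print(len(keyed["callbacks"]))  # 2: both states are kept
```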
Methods
load_state_dict
- PyroJitGuideWarmup.load_state_dict(state_dict)
Called when loading a checkpoint; implement to reload the callback’s state given its state_dict.
on_after_backward
- PyroJitGuideWarmup.on_after_backward(trainer, pl_module)
Called after loss.backward() and before optimizers are stepped.
- Return type: None
on_before_backward
- PyroJitGuideWarmup.on_before_backward(trainer, pl_module, loss)
Called before loss.backward().
- Return type: None
on_before_optimizer_step
- PyroJitGuideWarmup.on_before_optimizer_step(trainer, pl_module, optimizer, opt_idx)
Called before optimizer.step().
- Return type: None
on_before_zero_grad
- PyroJitGuideWarmup.on_before_zero_grad(trainer, pl_module, optimizer)
Called before optimizer.zero_grad().
- Return type: None
on_exception
- PyroJitGuideWarmup.on_exception(trainer, pl_module, exception)
Called when any trainer execution is interrupted by an exception.
- Return type: None
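A minimal sketch of implementing this hook: the callback only records the exception it receives. The toy `run` function assumes the trainer notifies callbacks and then re-raises the original exception; it stands in for the trainer and is not Lightning's code. `ExceptionRecorder` and `run` are hypothetical names.

```python
from typing import Any, List, Optional


class ExceptionRecorder:
    """Hypothetical callback that remembers why training was interrupted."""

    def __init__(self) -> None:
        self.last_exception: Optional[BaseException] = None

    def on_exception(self, trainer: Any, pl_module: Any, exception: BaseException) -> None:
        self.last_exception = exception


def run(callbacks: List[ExceptionRecorder]) -> None:
    """Toy stand-in for a trainer: notify callbacks, then re-raise."""
    try:
        raise RuntimeError("loss became NaN")
    except BaseException as exc:
        for cb in callbacks:
            cb.on_exception(None, None, exc)
        raise  # the hook observes the failure; it does not swallow it


recorder = ExceptionRecorder()
try:
    run([recorder])
except RuntimeError:
    pass
print(repr(recorder.last_exception))  # RuntimeError('loss became NaN')
```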
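on_fit_end
- PyroJitGuideWarmup.on_fit_end(trainer, pl_module)
Called when fit ends.
- Return type: None
on_fit_start
- PyroJitGuideWarmup.on_fit_start(trainer, pl_module)
Called when fit begins.
- Return type: None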
on_load_checkpoint
- PyroJitGuideWarmup.on_load_checkpoint(trainer, pl_module, checkpoint)
Called when loading a model checkpoint; use it to reload state.
- Parameters:
  - trainer (Trainer) – the current Trainer instance.
  - pl_module (LightningModule) – the current LightningModule instance.
  - checkpoint (Dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.
- Return type: None
on_predict_batch_end
- PyroJitGuideWarmup.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the predict batch ends.
- Return type: None
on_predict_batch_start
- PyroJitGuideWarmup.on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the predict batch begins.
- Return type: None
on_predict_end
- PyroJitGuideWarmup.on_predict_end(trainer, pl_module)
Called when prediction ends.
- Return type: None
on_predict_epoch_end
- PyroJitGuideWarmup.on_predict_epoch_end(trainer, pl_module, outputs)
Called when the predict epoch ends.
- Return type: None
on_predict_epoch_start
- PyroJitGuideWarmup.on_predict_epoch_start(trainer, pl_module)
Called when the predict epoch begins.
- Return type: None
on_predict_start
- PyroJitGuideWarmup.on_predict_start(trainer, pl_module)
Called when prediction begins.
- Return type: None
on_sanity_check_end
- PyroJitGuideWarmup.on_sanity_check_end(trainer, pl_module)
Called when the validation sanity check ends.
- Return type: None
on_sanity_check_start
- PyroJitGuideWarmup.on_sanity_check_start(trainer, pl_module)
Called when the validation sanity check starts.
- Return type: None
on_save_checkpoint
- PyroJitGuideWarmup.on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint, to give you a chance to store anything else you might want to save.
- Parameters:
  - trainer (Trainer) – the current Trainer instance.
  - pl_module (LightningModule) – the current LightningModule instance.
  - checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.
- Return type: None
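Because both checkpoint hooks receive the full checkpoint dictionary, a callback can add an entry on save and read it back on load. This is a minimal pure-Python sketch. `MetadataCallback` and the key `"metadata_tag"` are hypothetical, and the plain dictionary stands in for a real checkpoint. For per-callback state, the state_dict / load_state_dict pair keyed by state_key is the more targeted mechanism.

```python
from typing import Any, Dict


class MetadataCallback:
    """Hypothetical callback that stores extra metadata in the checkpoint."""

    def __init__(self, tag: str) -> None:
        self.tag = tag

    def on_save_checkpoint(self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]) -> None:
        # The hook receives the whole checkpoint dict and may add entries to it.
        checkpoint["metadata_tag"] = self.tag

    def on_load_checkpoint(self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]) -> None:
        # Read the same key back from the full checkpoint that was loaded,
        # keeping the current value if the key is absent.
        self.tag = checkpoint.get("metadata_tag", self.tag)


# Placeholder checkpoint contents; a real one holds model weights and more.
checkpoint: Dict[str, Any] = {"state_dict": {}, "epoch": 3}
MetadataCallback("run-42").on_save_checkpoint(None, None, checkpoint)

fresh = MetadataCallback("unset")
fresh.on_load_checkpoint(None, None, checkpoint)
print(fresh.tag)  # run-42
```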
on_test_batch_end
- PyroJitGuideWarmup.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the test batch ends.
- Return type: None
on_test_batch_start
- PyroJitGuideWarmup.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the test batch begins.
- Return type: None
on_test_end
- PyroJitGuideWarmup.on_test_end(trainer, pl_module)
Called when testing ends.
- Return type: None
on_test_epoch_end
- PyroJitGuideWarmup.on_test_epoch_end(trainer, pl_module)
Called when the test epoch ends.
- Return type: None
on_test_epoch_start
- PyroJitGuideWarmup.on_test_epoch_start(trainer, pl_module)
Called when the test epoch begins.
- Return type: None
on_test_start
- PyroJitGuideWarmup.on_test_start(trainer, pl_module)
Called when testing begins.
- Return type: None
on_train_batch_end
- PyroJitGuideWarmup.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when the train batch ends.
- Return type: None
Note
Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
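The arithmetic behind the note can be sketched as follows, assuming "normalized" means the value seen in this hook is the training_step loss divided by accumulate_grad_batches. `unnormalized_loss` is a hypothetical helper and the numbers are made up. With a tensor loss the same multiplication would apply.

```python
def unnormalized_loss(outputs: dict, accumulate_grad_batches: int) -> float:
    """Recover the training_step loss, assuming the hook's value was
    divided by accumulate_grad_batches."""
    return outputs["loss"] * accumulate_grad_batches


# Hypothetical numbers: training_step returned 2.0 and gradients are
# accumulated over 4 batches, so the hook would see 2.0 / 4 = 0.5.
outputs = {"loss": 2.0 / 4}
raw = unnormalized_loss(outputs, accumulate_grad_batches=4)
print(raw)  # 2.0
```

With accumulate_grad_batches=1 (no accumulation) the two values coincide.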
on_train_batch_start
- PyroJitGuideWarmup.on_train_batch_start(trainer, pl_module, batch, batch_idx)
Called when the train batch begins.
- Return type: None
on_train_end
- PyroJitGuideWarmup.on_train_end(trainer, pl_module)
Called when training ends.
- Return type: None
on_train_epoch_end
- PyroJitGuideWarmup.on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
- implement training_epoch_end in the LightningModule and access the outputs via the module, or
- cache data across the train batch hooks inside the callback implementation and post-process it in this hook.
- Return type: None
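The second option (caching inside the callback) can be sketched in pure Python. `EpochLossCollector` is a hypothetical callback, and the nested loop at the end is a minimal stand-in for the trainer calling the hooks over two epochs; it does not reproduce Lightning's full hook sequence.

```python
from statistics import mean
from typing import Any, Dict, List


class EpochLossCollector:
    """Hypothetical callback that caches per-batch losses for epoch-level use."""

    def __init__(self) -> None:
        self._losses: List[float] = []
        self.epoch_means: List[float] = []

    def on_train_epoch_start(self, trainer: Any, pl_module: Any) -> None:
        self._losses.clear()  # start each epoch with an empty cache

    def on_train_batch_end(
        self, trainer: Any, pl_module: Any, outputs: Dict[str, Any], batch: Any, batch_idx: int
    ) -> None:
        self._losses.append(float(outputs["loss"]))

    def on_train_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        # Post-process everything cached during the epoch.
        self.epoch_means.append(mean(self._losses))


# Minimal stand-in for the trainer's hook calls over two epochs.
cb = EpochLossCollector()
for epoch_losses in ([4.0, 2.0], [1.0, 3.0, 2.0]):
    cb.on_train_epoch_start(None, None)
    for idx, loss in enumerate(epoch_losses):
        cb.on_train_batch_end(None, None, {"loss": loss}, None, idx)
    cb.on_train_epoch_end(None, None)

print(cb.epoch_means)  # [3.0, 2.0]
```

Clearing the cache in on_train_epoch_start keeps each epoch's aggregate independent of the previous one.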
on_train_epoch_start
- PyroJitGuideWarmup.on_train_epoch_start(trainer, pl_module)
Called when the train epoch begins.
- Return type: None
on_train_start
- PyroJitGuideWarmup.on_train_start(trainer, pl_module)
Automatically warms up the Pyro guide when training starts.
The warmup is device-agnostic.
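A hook-based version of the warmup can be sketched without Pyro. `ToyGuideWarmup`, `ToyModule`, and the trainer attribute `train_batches` are hypothetical. The fallback used when `dataloader` is None is an assumption about what the constructor's default implies, not a description of the scvi-tools code. Device handling (moving the batch to the module's device) is omitted.

```python
from types import SimpleNamespace
from typing import Any, Iterable, List, Optional


class ToyGuideWarmup:
    """Pyro-free stand-in for a guide-warmup callback (hypothetical)."""

    def __init__(self, dataloader: Optional[Iterable[Any]] = None) -> None:
        self.dataloader = dataloader

    def on_train_start(self, trainer: Any, pl_module: Any) -> None:
        # Assumption: without an explicit dataloader, fall back to the
        # trainer's training batches (attribute name made up for this sketch).
        dl = self.dataloader if self.dataloader is not None else trainer.train_batches
        for batch in dl:
            pl_module.guide(batch)  # a single forward pass initializes the guide
            break


class ToyModule:
    """Records every batch its guide is called with."""

    def __init__(self) -> None:
        self.calls: List[Any] = []

    def guide(self, batch: Any) -> None:
        self.calls.append(batch)


trainer = SimpleNamespace(train_batches=[[0.1, 0.2], [0.3, 0.4]])

module = ToyModule()
ToyGuideWarmup().on_train_start(trainer, module)
print(module.calls)  # [[0.1, 0.2]]

module2 = ToyModule()
ToyGuideWarmup(dataloader=[["custom"]]).on_train_start(trainer, module2)
print(module2.calls)  # [['custom']]
```

Running the warmup in on_train_start means it happens once per fit, right before the first training batch.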
on_validation_batch_end
- PyroJitGuideWarmup.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the validation batch ends.
- Return type: None
on_validation_batch_start
- PyroJitGuideWarmup.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the validation batch begins.
- Return type: None
on_validation_end
- PyroJitGuideWarmup.on_validation_end(trainer, pl_module)
Called when the validation loop ends.
- Return type: None
on_validation_epoch_end
- PyroJitGuideWarmup.on_validation_epoch_end(trainer, pl_module)
Called when the validation epoch ends.
- Return type: None
on_validation_epoch_start
- PyroJitGuideWarmup.on_validation_epoch_start(trainer, pl_module)
Called when the validation epoch begins.
- Return type: None
on_validation_start
- PyroJitGuideWarmup.on_validation_start(trainer, pl_module)
Called when the validation loop begins.
- Return type: None
setup
- PyroJitGuideWarmup.setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
- Return type: None
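A minimal sketch of branching on `stage` in setup and undoing the work in teardown. It assumes `stage` is a string naming the entry point that is starting ("fit", "validate", "test", "predict", or "tune", matching the description above). `StageAwareCallback` and the buffer labels are hypothetical.

```python
from typing import Any, Optional


class StageAwareCallback:
    """Hypothetical callback that prepares different resources per stage."""

    def __init__(self) -> None:
        self.prepared: Optional[str] = None

    def setup(self, trainer: Any, pl_module: Any, stage: Optional[str]) -> None:
        # Branch on the entry point that is starting.
        if stage == "fit":
            self.prepared = "training buffers"
        elif stage in ("validate", "test"):
            self.prepared = "evaluation buffers"
        elif stage == "predict":
            self.prepared = "prediction buffers"
        else:
            self.prepared = None  # e.g. "tune" or an unknown stage

    def teardown(self, trainer: Any, pl_module: Any, stage: Optional[str]) -> None:
        # Release whatever setup() prepared, regardless of the stage.
        self.prepared = None


cb = StageAwareCallback()
cb.setup(None, None, "test")
print(cb.prepared)  # evaluation buffers
cb.teardown(None, None, "test")
print(cb.prepared)  # None
```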
state_dict
- PyroJitGuideWarmup.state_dict()
Called when saving a checkpoint; implement to generate the callback’s state_dict.
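A stateful callback implements state_dict and load_state_dict as a matching pair: the first returns plain values, the second restores them. This pure-Python sketch uses a hypothetical `BatchCounter` callback and a JSON round trip as a stand-in for writing and reading a checkpoint file (Lightning itself does not use JSON for checkpoints).

```python
import json
from typing import Any, Dict


class BatchCounter:
    """Hypothetical callback whose only state is a batch counter."""

    def __init__(self) -> None:
        self.batches_seen = 0

    def on_train_batch_end(self, trainer: Any, pl_module: Any, outputs: Any, batch: Any, batch_idx: int) -> None:
        self.batches_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        # Return only plain, serializable values.
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.batches_seen = state_dict["batches_seen"]


cb = BatchCounter()
for i in range(5):
    cb.on_train_batch_end(None, None, {}, None, i)

# Round-trip through JSON as a stand-in for saving and loading a checkpoint.
saved = json.dumps(cb.state_dict())
restored = BatchCounter()
restored.load_state_dict(json.loads(saved))
print(restored.batches_seen)  # 5
```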
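teardown
- PyroJitGuideWarmup.teardown(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune ends.
- Return type: None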