scvi.model.base.PyroJitGuideWarmup
- class scvi.model.base.PyroJitGuideWarmup(dataloader=None)
A callback to warm up a Pyro guide.
This helps initialize all the relevant parameters by running one minibatch through the Pyro model.
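The plain-Python sketch below shows why a single minibatch is enough. It assumes, as with Pyro's autoguides, that the guide creates its parameters only on its first call, so an optimizer set up before that call would find nothing to optimize. `LazyGuide` and `warmup` are hypothetical stand-ins, not scvi-tools or Pyro APIs.

```python
from typing import Dict, Iterable, List


class LazyGuide:
    """Hypothetical guide whose parameters only exist after the first call."""

    def __init__(self) -> None:
        self.params: Dict[str, List[float]] = {}

    def __call__(self, batch: Dict[str, List[float]]) -> None:
        # On the first call, create one parameter vector per input field,
        # sized to match the batch (mimics lazy parameter setup).
        if not self.params:
            for name, values in batch.items():
                self.params[f"loc_{name}"] = [0.0] * len(values)


def warmup(guide: LazyGuide, dataloader: Iterable[Dict[str, List[float]]]) -> None:
    """Run exactly one minibatch through the guide so its parameters exist."""
    for batch in dataloader:
        guide(batch)
        break  # a single minibatch is enough to materialize the parameters


guide = LazyGuide()
print(len(guide.params))  # 0: nothing for an optimizer to register yet
warmup(guide, [{"x": [1.0, 2.0, 3.0]}, {"x": [4.0, 5.0, 6.0]}])
print(sorted(guide.params))  # ['loc_x']
```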
Attributes table
state_key | Identifier for the state of the callback. |
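Methods table
load_state_dict | Called when loading a checkpoint; implement to reload callback state given the callback's state_dict. |
on_after_backward | Called after loss.backward() and before optimizers are stepped. |
on_batch_end | Deprecated since version v1.6. |
on_batch_start | Deprecated since version v1.6. |
on_before_accelerator_backend_setup | Deprecated since version v1.6. |
on_before_backward | Called before loss.backward(). |
on_before_optimizer_step | Called before optimizer.step(). |
on_before_zero_grad | Called before optimizer.zero_grad(). |
on_configure_sharded_model | Deprecated since version v1.6. |
on_epoch_end | Deprecated since version v1.6. |
on_epoch_start | Deprecated since version v1.6. |
on_exception | Called when any trainer execution is interrupted by an exception. |
on_fit_end | Called when fit ends. |
on_fit_start | Called when fit begins. |
on_init_end | Deprecated since version v1.6. |
on_init_start | Deprecated since version v1.6. |
on_load_checkpoint | Called when loading a model checkpoint; use it to reload state. |
on_predict_batch_end | Called when the predict batch ends. |
on_predict_batch_start | Called when the predict batch begins. |
on_predict_end | Called when predict ends. |
on_predict_epoch_end | Called when the predict epoch ends. |
on_predict_epoch_start | Called when the predict epoch begins. |
on_predict_start | Called when predict begins. |
on_pretrain_routine_end | Deprecated since version v1.6. |
on_pretrain_routine_start | Deprecated since version v1.6. |
on_sanity_check_end | Called when the validation sanity check ends. |
on_sanity_check_start | Called when the validation sanity check starts. |
on_save_checkpoint | Called when saving a checkpoint to give you a chance to store anything else you might want to save. |
on_test_batch_end | Called when the test batch ends. |
on_test_batch_start | Called when the test batch begins. |
on_test_end | Called when testing ends. |
on_test_epoch_end | Called when the test epoch ends. |
on_test_epoch_start | Called when the test epoch begins. |
on_test_start | Called when testing begins. |
on_train_batch_end | Called when the train batch ends. |
on_train_batch_start | Called when the train batch begins. |
on_train_end | Called when training ends. |
on_train_epoch_end | Called when the train epoch ends. |
on_train_epoch_start | Called when the train epoch begins. |
on_train_start | Warm up the Pyro guide automatically. |
on_validation_batch_end | Called when the validation batch ends. |
on_validation_batch_start | Called when the validation batch begins. |
on_validation_end | Called when the validation loop ends. |
on_validation_epoch_end | Called when the validation epoch ends. |
on_validation_epoch_start | Called when the validation epoch begins. |
on_validation_start | Called when the validation loop begins. |
setup | Called when fit, validate, test, predict, or tune begins. |
state_dict | Called when saving a checkpoint; implement to generate the callback's state_dict. |
teardown | Called when fit, validate, test, predict, or tune ends. |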
Attributes
state_key
- PyroJitGuideWarmup.state_key
Identifier for the state of the callback.
Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) you want to maintain separate state for multiple instances of that callback.
- Return type: str
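The unique-key requirement can be illustrated with plain dictionaries. In this sketch, `CounterCallback` and `save_callback_states` are hypothetical, not Lightning APIs. A key built from the class name alone lets a second instance overwrite the first instance's entry in `checkpoint["callbacks"]`, while a key qualified with a distinguishing constructor argument keeps both.

```python
from typing import Any, Dict, List


class CounterCallback:
    """Hypothetical stateful callback used only for this illustration."""

    def __init__(self, name: str, count: int, unique_key: bool) -> None:
        self.name = name
        self.count = count
        self.unique_key = unique_key

    @property
    def state_key(self) -> str:
        if self.unique_key:
            # Qualify the key with the distinguishing constructor argument.
            return f"{type(self).__name__}[name={self.name}]"
        # Class name only: every instance shares the same key.
        return type(self).__name__


def save_callback_states(callbacks: List[CounterCallback]) -> Dict[str, Any]:
    """Store each callback's state under checkpoint["callbacks"][state_key]."""
    checkpoint: Dict[str, Any] = {"callbacks": {}}
    for cb in callbacks:
        checkpoint["callbacks"][cb.state_key] = {"count": cb.count}
    return checkpoint


shared = save_callback_states(
    [CounterCallback("a", 3, unique_key=False), CounterCallback("b", 7, unique_key=False)]
)
print(shared["callbacks"])  # {'CounterCallback': {'count': 7}}: "a" was overwritten

unique = save_callback_states(
    [CounterCallback("a", 3, unique_key=True), CounterCallback("b", 7, unique_key=True)]
)
print(sorted(unique["callbacks"]))
# ['CounterCallback[name=a]', 'CounterCallback[name=b]']
```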
Methods
load_state_dict
- PyroJitGuideWarmup.load_state_dict(state_dict)
Called when loading a checkpoint; implement to reload callback state given the callback’s state_dict.
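The state_dict / load_state_dict pair can be sketched as a round trip through a checkpoint dictionary. `WarmupCounter` is a hypothetical stateful callback written in plain Python for illustration; it is not part of scvi-tools or Lightning.

```python
from typing import Any, Dict


class WarmupCounter:
    """Hypothetical callback that remembers how many warmup passes it ran."""

    def __init__(self) -> None:
        self.warmups_run = 0

    @property
    def state_key(self) -> str:
        return type(self).__name__

    def state_dict(self) -> Dict[str, Any]:
        # Return a fresh dict so later changes to the callback
        # do not alter what was saved.
        return {"warmups_run": self.warmups_run}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.warmups_run = state_dict["warmups_run"]


original = WarmupCounter()
original.warmups_run = 1
checkpoint = {"callbacks": {original.state_key: original.state_dict()}}

restored = WarmupCounter()  # e.g. a new process resuming training
restored.load_state_dict(checkpoint["callbacks"][restored.state_key])
print(restored.warmups_run)  # 1
```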
on_after_backward
- PyroJitGuideWarmup.on_after_backward(trainer, pl_module)
Called after loss.backward() and before optimizers are stepped.
- Return type: None
on_batch_end
- PyroJitGuideWarmup.on_batch_end(trainer, pl_module)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8. Use on_train_batch_end instead.
Called when the training batch ends.
- Return type: None
on_batch_start
- PyroJitGuideWarmup.on_batch_start(trainer, pl_module)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8. Use on_train_batch_start instead.
Called when the training batch begins.
- Return type: None
on_before_accelerator_backend_setup
- PyroJitGuideWarmup.on_before_accelerator_backend_setup(trainer, pl_module)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8. Use setup() instead.
Called before the accelerator is set up.
- Return type: None
on_before_backward
- PyroJitGuideWarmup.on_before_backward(trainer, pl_module, loss)
Called before loss.backward().
- Return type: None
on_before_optimizer_step
- PyroJitGuideWarmup.on_before_optimizer_step(trainer, pl_module, optimizer, opt_idx)
Called before optimizer.step().
- Return type: None
on_before_zero_grad
- PyroJitGuideWarmup.on_before_zero_grad(trainer, pl_module, optimizer)
Called before optimizer.zero_grad().
- Return type: None
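The four hooks above bracket the standard PyTorch update calls. The plain-Python sketch below records where each hook fires relative to `zero_grad()`, `backward()`, and `step()` in one update. The toy classes are hypothetical, and the zero_grad, backward, step ordering is the usual PyTorch idiom, assumed here rather than taken from Lightning's training loop.

```python
from typing import List

calls: List[str] = []


class ToyLoss:
    def backward(self) -> None:
        calls.append("loss.backward()")


class ToyOptimizer:
    def zero_grad(self) -> None:
        calls.append("optimizer.zero_grad()")

    def step(self) -> None:
        calls.append("optimizer.step()")


class RecordingCallback:
    """Hypothetical callback that logs each hook it receives."""

    def on_before_zero_grad(self, trainer, pl_module, optimizer) -> None:
        calls.append("on_before_zero_grad")

    def on_before_backward(self, trainer, pl_module, loss) -> None:
        calls.append("on_before_backward")

    def on_after_backward(self, trainer, pl_module) -> None:
        calls.append("on_after_backward")

    def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx) -> None:
        calls.append("on_before_optimizer_step")


def update_step(cb: RecordingCallback, loss: ToyLoss, opt: ToyOptimizer) -> None:
    trainer = pl_module = None  # placeholders; the toy hooks ignore them
    cb.on_before_zero_grad(trainer, pl_module, opt)
    opt.zero_grad()
    cb.on_before_backward(trainer, pl_module, loss)
    loss.backward()
    cb.on_after_backward(trainer, pl_module)
    cb.on_before_optimizer_step(trainer, pl_module, opt, 0)
    opt.step()


update_step(RecordingCallback(), ToyLoss(), ToyOptimizer())
print(calls)
```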
on_configure_sharded_model
- PyroJitGuideWarmup.on_configure_sharded_model(trainer, pl_module)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8. Use setup() instead.
Called before the sharded model is configured.
- Return type: None
on_epoch_end
- PyroJitGuideWarmup.on_epoch_end(trainer, pl_module)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8. Use on_<train/validation/test>_epoch_end instead.
Called when a train, validation, or test epoch ends.
- Return type: None
on_epoch_start
- PyroJitGuideWarmup.on_epoch_start(trainer, pl_module)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8. Use on_<train/validation/test>_epoch_start instead.
Called when a train, validation, or test epoch begins.
- Return type: None
on_exception
- PyroJitGuideWarmup.on_exception(trainer, pl_module, exception)
Called when any trainer execution is interrupted by an exception.
- Return type: None
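The sketch below shows an on_exception hook receiving the interrupting exception. `ExceptionLogger` and `run` are hypothetical plain-Python stand-ins for a callback and a trainer entry point. The sketch assumes the trainer notifies callbacks and then re-raises, so the hook observes the error without suppressing it.

```python
from typing import Callable, List, Optional


class ExceptionLogger:
    """Hypothetical callback that remembers the exception it was shown."""

    def __init__(self) -> None:
        self.seen: Optional[BaseException] = None

    def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
        self.seen = exception


def run(callbacks: List[ExceptionLogger], fit: Callable[[], None]) -> None:
    """Toy stand-in for a trainer entry point such as fit()."""
    try:
        fit()
    except Exception as exc:
        for cb in callbacks:
            cb.on_exception(None, None, exc)
        raise  # callbacks observe the error; they do not swallow it


def failing_fit() -> None:
    raise RuntimeError("NaN loss")


logger = ExceptionLogger()
try:
    run([logger], failing_fit)
except RuntimeError:
    pass
print(repr(logger.seen))  # RuntimeError('NaN loss')
```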
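on_fit_end
- PyroJitGuideWarmup.on_fit_end(trainer, pl_module)
Called when fit ends.
- Return type: None
on_fit_start
- PyroJitGuideWarmup.on_fit_start(trainer, pl_module)
Called when fit begins.
- Return type: None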
on_init_end
- PyroJitGuideWarmup.on_init_end(trainer)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8.
Called when the trainer initialization ends; the model has not yet been set.
- Return type: None
on_init_start
- PyroJitGuideWarmup.on_init_start(trainer)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8.
Called when the trainer initialization begins; the model has not yet been set.
- Return type: None
on_load_checkpoint
- PyroJitGuideWarmup.on_load_checkpoint(trainer, pl_module, callback_state)
Called when loading a model checkpoint; use it to reload state.
- Parameters:
trainer (Trainer) – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
callback_state (Dict[str, Any]) – the callback state returned by on_save_checkpoint.
Note
The on_load_checkpoint hook won’t be called with an undefined state. If your on_load_checkpoint hook behavior doesn’t rely on a state, you will still need to override on_save_checkpoint to return a dummy state.
Deprecated since version v1.6: This callback hook will change its signature and behavior in v1.8. If you wish to load the state of the callback, use Callback.load_state_dict instead. In v1.8, Callback.on_load_checkpoint(checkpoint) will receive the entire loaded checkpoint dictionary instead of only the callback state from the checkpoint.
- Return type: None
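The dummy-state requirement in the note can be modeled with a toy save/load pair. This is a hypothetical plain-Python simulation of the v1.6 behavior described above, not Lightning code. It assumes callback states are keyed by class name and that a `None` return from on_save_checkpoint means no state entry is written.

```python
from typing import Any, Dict, Optional


class Stateless:
    """Hypothetical callback whose load hook does not need any saved state."""

    def __init__(self, return_dummy: bool) -> None:
        self.return_dummy = return_dummy
        self.load_called = False

    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> Optional[dict]:
        # Returning a (dummy) dict is what makes a state entry exist at all.
        return {"dummy": True} if self.return_dummy else None

    def on_load_checkpoint(self, trainer, pl_module, callback_state) -> None:
        self.load_called = True


def save(cb: Stateless) -> Dict[str, Any]:
    checkpoint: Dict[str, Any] = {"callbacks": {}}
    state = cb.on_save_checkpoint(None, None, checkpoint)
    if state is not None:
        checkpoint["callbacks"][type(cb).__name__] = state
    return checkpoint


def load(cb: Stateless, checkpoint: Dict[str, Any]) -> None:
    state = checkpoint["callbacks"].get(type(cb).__name__)
    if state is not None:  # no saved state: the hook is skipped
        cb.on_load_checkpoint(None, None, state)


without = Stateless(return_dummy=False)
load(without, save(without))
with_dummy = Stateless(return_dummy=True)
load(with_dummy, save(with_dummy))
print(without.load_called, with_dummy.load_called)  # False True
```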
on_predict_batch_end
- PyroJitGuideWarmup.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the predict batch ends.
- Return type: None
on_predict_batch_start
- PyroJitGuideWarmup.on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the predict batch begins.
- Return type: None
on_predict_end
- PyroJitGuideWarmup.on_predict_end(trainer, pl_module)
Called when predict ends.
- Return type: None
on_predict_epoch_end
- PyroJitGuideWarmup.on_predict_epoch_end(trainer, pl_module, outputs)
Called when the predict epoch ends.
- Return type: None
on_predict_epoch_start
- PyroJitGuideWarmup.on_predict_epoch_start(trainer, pl_module)
Called when the predict epoch begins.
- Return type: None
on_predict_start
- PyroJitGuideWarmup.on_predict_start(trainer, pl_module)
Called when predict begins.
- Return type: None
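The predict hooks nest as their names suggest: start, then epoch start, then per-batch start/end pairs, then epoch end and end. The toy driver below is a hypothetical plain-Python illustration of that nesting using the argument lists shown above. It assumes that `outputs` at epoch end is one list of batch results per dataloader, and it uses `batch * 2` as a stand-in for a prediction step.

```python
from typing import Any, Iterable, List

events: List[str] = []


class PredictTracer:
    """Hypothetical callback that logs predict hooks and their arguments."""

    def on_predict_start(self, trainer, pl_module) -> None:
        events.append("predict_start")

    def on_predict_epoch_start(self, trainer, pl_module) -> None:
        events.append("epoch_start")

    def on_predict_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx) -> None:
        events.append(f"batch_start[{dataloader_idx}:{batch_idx}]")

    def on_predict_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ) -> None:
        events.append(f"batch_end[{dataloader_idx}:{batch_idx}]={outputs}")

    def on_predict_epoch_end(self, trainer, pl_module, outputs) -> None:
        events.append(f"epoch_end={outputs}")

    def on_predict_end(self, trainer, pl_module) -> None:
        events.append("predict_end")


def toy_predict(cb: PredictTracer, dataloaders: List[Iterable[Any]]) -> None:
    """Hypothetical driver: one predict epoch over several dataloaders."""
    cb.on_predict_start(None, None)
    cb.on_predict_epoch_start(None, None)
    all_outputs = []
    for dl_idx, dl in enumerate(dataloaders):
        dl_outputs = []
        for b_idx, batch in enumerate(dl):
            cb.on_predict_batch_start(None, None, batch, b_idx, dl_idx)
            out = batch * 2  # stand-in for the model's prediction step
            dl_outputs.append(out)
            cb.on_predict_batch_end(None, None, out, batch, b_idx, dl_idx)
        all_outputs.append(dl_outputs)
    cb.on_predict_epoch_end(None, None, all_outputs)
    cb.on_predict_end(None, None)


toy_predict(PredictTracer(), [[1, 2], [5]])
print(events)
```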
on_pretrain_routine_end
- PyroJitGuideWarmup.on_pretrain_routine_end(trainer, pl_module)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8. Use on_fit_start instead.
Called when the pretrain routine ends.
- Return type: None
on_pretrain_routine_start
- PyroJitGuideWarmup.on_pretrain_routine_start(trainer, pl_module)
Deprecated since version v1.6: This callback hook was deprecated in v1.6 and will be removed in v1.8. Use on_fit_start instead.
Called when the pretrain routine begins.
- Return type: None
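The deprecations on this page can be collected into a migration check. The mapping below copies only the replacements named in this reference and uses `None` where no replacement is given. `BaseCallback`, `LegacyCallback`, and `deprecated_overrides` are hypothetical plain-Python helpers, not part of Lightning or scvi-tools.

```python
from typing import Dict, List, Optional

# Deprecated hook -> replacement named in this reference (None: none named).
DEPRECATED_HOOKS: Dict[str, Optional[str]] = {
    "on_batch_end": "on_train_batch_end",
    "on_batch_start": "on_train_batch_start",
    "on_before_accelerator_backend_setup": "setup",
    "on_configure_sharded_model": "setup",
    "on_epoch_end": "on_<train/validation/test>_epoch_end",
    "on_epoch_start": "on_<train/validation/test>_epoch_start",
    "on_init_end": None,
    "on_init_start": None,
    "on_pretrain_routine_end": "on_fit_start",
    "on_pretrain_routine_start": "on_fit_start",
}


class BaseCallback:
    """Hypothetical base class with no-op hooks, standing in for a real one."""

    def on_batch_end(self, trainer, pl_module) -> None:
        pass

    def on_batch_start(self, trainer, pl_module) -> None:
        pass


def deprecated_overrides(cls: type, base: type) -> List[str]:
    """List deprecated hooks that `cls` defines or overrides relative to `base`."""
    report: List[str] = []
    for hook, replacement in DEPRECATED_HOOKS.items():
        impl = getattr(cls, hook, None)
        # An inherited, un-overridden hook resolves to the base's function object.
        if impl is not None and impl is not getattr(base, hook, None):
            report.append(f"{hook} -> {replacement or '(no replacement named)'}")
    return report


class LegacyCallback(BaseCallback):
    def on_batch_end(self, trainer, pl_module) -> None:
        print("batch done")

    def on_init_start(self, trainer) -> None:
        print("trainer init")


print(deprecated_overrides(LegacyCallback, BaseCallback))
# ['on_batch_end -> on_train_batch_end', 'on_init_start -> (no replacement named)']
```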
on_sanity_check_end
- PyroJitGuideWarmup.on_sanity_check_end(trainer, pl_module)
Called when the validation sanity check ends.
- Return type: None
on_sanity_check_start
- PyroJitGuideWarmup.on_sanity_check_start(trainer, pl_module)
Called when the validation sanity check starts.
- Return type: None
on_save_checkpoint
- PyroJitGuideWarmup.on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer (Trainer) – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.
- Return type: Optional[dict]
- Returns:
None or the callback state. Support for returning callback state will be removed in v1.8.
Deprecated since version v1.6: Returning a value from this method was deprecated in v1.6 and will be removed in v1.8. Implement Callback.state_dict instead to return state. In v1.8, Callback.on_save_checkpoint can only return None.
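The sketch below shows the non-deprecated way to use this hook: add entries to the `checkpoint` dictionary in place and return `None`, leaving callback state to `state_dict`. `RecordWarmupSeed` and the `"warmup_seed"` key are hypothetical, and the checkpoint here is a plain dict standing in for the one a trainer would pass.

```python
from typing import Any, Dict


class RecordWarmupSeed:
    """Hypothetical callback that adds one extra entry to each checkpoint."""

    def __init__(self, seed: int) -> None:
        self.seed = seed

    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> None:
        # Write into the dictionary that is about to be saved and return None;
        # returning callback state from this hook is the deprecated path.
        checkpoint["warmup_seed"] = self.seed


checkpoint: Dict[str, Any] = {"state_dict": {}, "callbacks": {}}
returned = RecordWarmupSeed(seed=42).on_save_checkpoint(None, None, checkpoint)
print(returned, checkpoint["warmup_seed"])  # None 42
```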
on_test_batch_end
- PyroJitGuideWarmup.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the test batch ends.
- Return type: None
on_test_batch_start
- PyroJitGuideWarmup.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the test batch begins.
- Return type: None
on_test_end
- PyroJitGuideWarmup.on_test_end(trainer, pl_module)
Called when testing ends.
- Return type: None
on_test_epoch_end
- PyroJitGuideWarmup.on_test_epoch_end(trainer, pl_module)
Called when the test epoch ends.
- Return type: None
on_test_epoch_start
- PyroJitGuideWarmup.on_test_epoch_start(trainer, pl_module)
Called when the test epoch begins.
- Return type: None
on_test_start
- PyroJitGuideWarmup.on_test_start(trainer, pl_module)
Called when testing begins.
- Return type: None
on_train_batch_end
- PyroJitGuideWarmup.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when the train batch ends.
- Return type: None
on_train_batch_start
- PyroJitGuideWarmup.on_train_batch_start(trainer, pl_module, batch, batch_idx)
Called when the train batch begins.
- Return type: None
on_train_end
- PyroJitGuideWarmup.on_train_end(trainer, pl_module)
Called when training ends.
- Return type: None
on_train_epoch_end
- PyroJitGuideWarmup.on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
1. implement training_epoch_end in the LightningModule and access the outputs via the module, or
2. cache data across train batch hooks inside the callback implementation and post-process it in this hook.
- Return type: None
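The second option, caching across train batch hooks, can be sketched as a callback that collects per-batch losses and averages them when the epoch ends. `EpochLossAverager` is hypothetical. The sketch assumes `outputs` is a dict with a `"loss"` entry and drives the hooks by hand rather than through a trainer.

```python
from typing import Any, Dict, List


class EpochLossAverager:
    """Hypothetical callback: caches per-batch losses, averages them per epoch."""

    def __init__(self) -> None:
        self._batch_losses: List[float] = []
        self.epoch_means: List[float] = []

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        self._batch_losses.clear()  # start each epoch with an empty cache

    def on_train_batch_end(
        self, trainer, pl_module, outputs: Dict[str, Any], batch, batch_idx
    ) -> None:
        self._batch_losses.append(float(outputs["loss"]))

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        # Post-process the cached values once the epoch is over.
        if self._batch_losses:
            self.epoch_means.append(sum(self._batch_losses) / len(self._batch_losses))


cb = EpochLossAverager()
for epoch_losses in ([4.0, 2.0], [1.0, 1.0, 1.0]):
    cb.on_train_epoch_start(None, None)
    for i, loss in enumerate(epoch_losses):
        cb.on_train_batch_end(None, None, {"loss": loss}, None, i)
    cb.on_train_epoch_end(None, None)
print(cb.epoch_means)  # [3.0, 1.0]
```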
on_train_epoch_start
- PyroJitGuideWarmup.on_train_epoch_start(trainer, pl_module)
Called when the train epoch begins.
- Return type: None
on_train_start
- PyroJitGuideWarmup.on_train_start(trainer, pl_module)
Warm up the Pyro guide automatically.
The warmup is device agnostic.
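The general shape of such a hook can be sketched without Lightning or Pyro. This is not the scvi-tools source. `GuideWarmupSketch`, `warm_up`, and `to_device` are hypothetical names, and falling back to the training dataloader when `dataloader=None` is an assumption. Moving each batch entry with a caller-supplied `to_device` function is one way to keep the warmup device agnostic.

```python
from typing import Any, Callable, Dict, Iterable, List, Optional

Batch = Dict[str, Any]


class GuideWarmupSketch:
    """Hypothetical re-creation of the warmup idea; not the scvi-tools source."""

    def __init__(self, dataloader: Optional[Iterable[Batch]] = None) -> None:
        self.dataloader = dataloader

    def warm_up(
        self,
        guide: Callable[..., Any],
        train_dataloader: Iterable[Batch],
        to_device: Callable[[Any], Any],
    ) -> None:
        # Prefer the dataloader given at construction; otherwise fall back
        # to the training dataloader (assumed behavior).
        dl = self.dataloader if self.dataloader is not None else train_dataloader
        for batch in dl:
            # Move every entry to the model's device before calling the guide.
            guide(**{k: to_device(v) for k, v in batch.items()})
            break  # one minibatch is enough


seen: List[Batch] = []


def guide(**kwargs: Any) -> None:
    seen.append(kwargs)


GuideWarmupSketch().warm_up(guide, [{"x": 1}, {"x": 2}], to_device=lambda v: ("cuda:0", v))
print(seen)  # [{'x': ('cuda:0', 1)}]
```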
on_validation_batch_end
- PyroJitGuideWarmup.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the validation batch ends.
- Return type: None
on_validation_batch_start
- PyroJitGuideWarmup.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the validation batch begins.
- Return type: None
on_validation_end
- PyroJitGuideWarmup.on_validation_end(trainer, pl_module)
Called when the validation loop ends.
- Return type: None
on_validation_epoch_end
- PyroJitGuideWarmup.on_validation_epoch_end(trainer, pl_module)
Called when the validation epoch ends.
- Return type: None
on_validation_epoch_start
- PyroJitGuideWarmup.on_validation_epoch_start(trainer, pl_module)
Called when the validation epoch begins.
- Return type: None
on_validation_start
- PyroJitGuideWarmup.on_validation_start(trainer, pl_module)
Called when the validation loop begins.
- Return type: None
setup
- PyroJitGuideWarmup.setup(trainer, pl_module, stage=None)
Called when fit, validate, test, predict, or tune begins.
- Return type: None
state_dict
- PyroJitGuideWarmup.state_dict()
Called when saving a checkpoint; implement to generate the callback’s state_dict.
teardown