scvi.model.base.PyroJitGuideWarmup

class scvi.model.base.PyroJitGuideWarmup(dataloader=None)

A callback to warm up a Pyro guide.

This helps initialize all the relevant parameters by running one minibatch through the Pyro model.
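A toy model shows why one minibatch is enough. Pyro's autoguides, for example, set up their parameters lazily on the first call, once the data shapes are known. The stdlib-only sketch below imitates that behaviour with a hypothetical `LazyGuide` class and a hypothetical `warmup` helper. It illustrates the idea only and is not scvi-tools or Pyro code.

```python
class LazyGuide:
    """Toy stand-in for a guide that creates its parameters on first call."""

    def __init__(self):
        self.params = {}
        self.calls = 0

    def __call__(self, batch):
        self.calls += 1
        if not self.params:
            # Parameter shapes depend on the data, so they can only be
            # allocated once a real minibatch has been seen.
            n_features = len(batch[0])
            self.params["loc"] = [0.0] * n_features
            self.params["scale"] = [1.0] * n_features
        return self.params


def warmup(guide, dataloader):
    """Run exactly one minibatch through the guide, then stop."""
    for batch in dataloader:
        guide(batch)
        break


def make_loader(n_batches, batch_size, n_features):
    """Hypothetical generator of minibatches (lists of feature rows)."""
    for b in range(n_batches):
        yield [[float(b + i + j) for j in range(n_features)] for i in range(batch_size)]


guide = LazyGuide()
print(guide.params)  # {}
loader = make_loader(n_batches=5, batch_size=4, n_features=3)
warmup(guide, loader)
print(sorted(guide.params), guide.calls)  # ['loc', 'scale'] 1
print(len(list(loader)))  # 4 batches left untouched for training
```

After the warmup call, every parameter already exists, and the remaining batches are untouched.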

Attributes table

state_key

Identifier for the state of the callback.

Methods table

load_state_dict(state_dict)

Called when loading a checkpoint, implement to reload callback state given callback's state_dict.

on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers are stepped.

on_before_backward(trainer, pl_module, loss)

Called before loss.backward().

on_before_optimizer_step(trainer, pl_module, ...)

Called before optimizer.step().

on_before_zero_grad(trainer, pl_module, ...)

Called before optimizer.zero_grad().

on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

on_fit_end(trainer, pl_module)

Called when fit ends.

on_fit_start(trainer, pl_module)

Called when fit begins.

on_load_checkpoint(trainer, pl_module, ...)

Called when loading a model checkpoint, use to reload state.

on_predict_batch_end(trainer, pl_module, ...)

Called when the predict batch ends.

on_predict_batch_start(trainer, pl_module, ...)

Called when the predict batch begins.

on_predict_end(trainer, pl_module)

Called when predict ends.

on_predict_epoch_end(trainer, pl_module)

Called when the predict epoch ends.

on_predict_epoch_start(trainer, pl_module)

Called when the predict epoch begins.

on_predict_start(trainer, pl_module)

Called when predict begins.

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

on_save_checkpoint(trainer, pl_module, ...)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

on_test_batch_end(trainer, pl_module, ...[, ...])

Called when the test batch ends.

on_test_batch_start(trainer, pl_module, ...)

Called when the test batch begins.

on_test_end(trainer, pl_module)

Called when the test ends.

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

on_test_start(trainer, pl_module)

Called when the test begins.

on_train_batch_end(trainer, pl_module, ...)

Called when the train batch ends.

on_train_batch_start(trainer, pl_module, ...)

Called when the train batch begins.

on_train_end(trainer, pl_module)

Called when training ends.

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

on_train_start(trainer, pl_module)

Warm up the Pyro guide automatically.

on_validation_batch_end(trainer, pl_module, ...)

Called when the validation batch ends.

on_validation_batch_start(trainer, ...[, ...])

Called when the validation batch begins.

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

state_dict()

Called when saving a checkpoint, implement to generate callback's state_dict.

teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

Attributes

PyroJitGuideWarmup.state_key

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary via checkpoint["callbacks"][state_key]. A callback implementation must provide a unique state key if (1) the callback has state and (2) the state of multiple instances of that callback needs to be kept separately.
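Plain dictionaries are enough to show the collision that a unique key avoids. To our understanding, Lightning's base `Callback` uses the class's qualified name as the default `state_key`, so two instances of the same stateful callback would write to the same slot. The sketch below uses toy classes, not Lightning's. It contrasts that default with a key that also encodes a distinguishing constructor argument.

```python
class CountingCallback:
    """Toy stateful callback using the class name as its state key."""

    def __init__(self, name):
        self.name = name
        self.count = 0

    @property
    def state_key(self):
        return type(self).__qualname__


class KeyedCountingCallback(CountingCallback):
    """Same callback, but the key also encodes the distinguishing argument."""

    @property
    def state_key(self):
        return f"{type(self).__qualname__}[{self.name}]"


def collect_states(callbacks):
    """Store each callback's state under checkpoint["callbacks"][state_key]."""
    checkpoint = {"callbacks": {}}
    for cb in callbacks:
        checkpoint["callbacks"][cb.state_key] = {"count": cb.count}
    return checkpoint


a, b = CountingCallback("a"), CountingCallback("b")
a.count, b.count = 1, 2
print(collect_states([a, b]))  # {'callbacks': {'CountingCallback': {'count': 2}}}

c, d = KeyedCountingCallback("a"), KeyedCountingCallback("b")
c.count, d.count = 1, 2
print(sorted(collect_states([c, d])["callbacks"]))
# ['KeyedCountingCallback[a]', 'KeyedCountingCallback[b]']
```

With the default key, the second instance silently overwrites the first one's state.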

Methods

PyroJitGuideWarmup.load_state_dict(state_dict)

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters:

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

PyroJitGuideWarmup.on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers are stepped.

Return type:

None

PyroJitGuideWarmup.on_before_backward(trainer, pl_module, loss)

Called before loss.backward().

Return type:

None

PyroJitGuideWarmup.on_before_optimizer_step(trainer, pl_module, optimizer)

Called before optimizer.step().

Return type:

None

PyroJitGuideWarmup.on_before_zero_grad(trainer, pl_module, optimizer)

Called before optimizer.zero_grad().

Return type:

None
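The sketch below is a hypothetical single optimization step, written with the standard library only. It calls the gradient-related hooks at the points their descriptions give: each hook fires immediately before or after the operation it names. Lightning's real loop adds more machinery (closures, precision plugins, gradient accumulation), so treat `toy_optimization_step` and the toy classes only as a reading aid for the descriptions above.

```python
class GradHookRecorder:
    """Toy callback that logs the gradient-related hooks it receives."""

    def __init__(self, log):
        self.log = log

    def on_before_zero_grad(self, trainer, pl_module, optimizer):
        self.log.append("on_before_zero_grad")

    def on_before_backward(self, trainer, pl_module, loss):
        self.log.append("on_before_backward")

    def on_after_backward(self, trainer, pl_module):
        self.log.append("on_after_backward")

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        self.log.append("on_before_optimizer_step")


class ToyOptimizer:
    """Stand-in optimizer that only records zero_grad() and step() calls."""

    def __init__(self, log):
        self.log = log

    def zero_grad(self):
        self.log.append("zero_grad")

    def step(self):
        self.log.append("step")


def toy_optimization_step(callback, optimizer, loss, log):
    """One hypothetical step; each hook fires next to the call it names."""
    callback.on_before_zero_grad(None, None, optimizer)
    optimizer.zero_grad()
    callback.on_before_backward(None, None, loss)
    log.append("backward")  # stands in for loss.backward()
    callback.on_after_backward(None, None)
    callback.on_before_optimizer_step(None, None, optimizer)
    optimizer.step()


log = []
toy_optimization_step(GradHookRecorder(log), ToyOptimizer(log), loss=0.5, log=log)
print(log)
```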

PyroJitGuideWarmup.on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

Return type:

None
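A callback might use this hook to record a failed run or clean up after it. The stdlib-only sketch below uses a hypothetical runner, `run_with_callbacks`, that passes the exception to each callback and then re-raises it. The runner is an illustration, not Lightning's trainer.

```python
class ExceptionRecorder:
    """Toy callback that remembers which exceptions it was notified about."""

    def __init__(self):
        self.seen = []

    def on_exception(self, trainer, pl_module, exception):
        self.seen.append((type(exception).__name__, str(exception)))


def run_with_callbacks(step, callbacks):
    """Hypothetical runner: notify every callback, then re-raise the error."""
    try:
        return step()
    except Exception as exc:
        for cb in callbacks:
            cb.on_exception(None, None, exc)
        raise


def failing_step():
    raise RuntimeError("simulated failure")


recorder = ExceptionRecorder()
try:
    run_with_callbacks(failing_step, [recorder])
except RuntimeError:
    pass
print(recorder.seen)  # [('RuntimeError', 'simulated failure')]
```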

PyroJitGuideWarmup.on_fit_end(trainer, pl_module)

Called when fit ends.

Return type:

None

PyroJitGuideWarmup.on_fit_start(trainer, pl_module)

Called when fit begins.

Return type:

None
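The warmup lives in `on_train_start`, so it helps to see where that hook sits among the lifecycle hooks. The stdlib-only sketch below drives a toy recorder through a simplified fit, in roughly the order we understand Lightning to use:

1. Setup, then `on_fit_start`.
2. The sanity check.
3. `on_train_start`.
4. The epoch and batch hooks.
5. `on_fit_end`, then teardown.

Validation, optimizer and checkpoint hooks are left out, and `toy_fit` and `HookRecorder` are hypothetical.

```python
class HookRecorder:
    """Toy callback: any hook-like attribute logs its own name when called."""

    def __init__(self):
        self.events = []

    def __getattr__(self, name):
        # Only called for attributes that do not exist on the instance.
        if name.startswith("on_") or name in ("setup", "teardown"):
            return lambda *args, **kwargs: self.events.append(name)
        raise AttributeError(name)


def toy_fit(cb, n_epochs=1, n_batches=1):
    """Simplified fit: validation, optimizer and checkpoint hooks omitted."""
    cb.setup(None, None, "fit")
    cb.on_fit_start(None, None)
    cb.on_sanity_check_start(None, None)
    cb.on_sanity_check_end(None, None)
    cb.on_train_start(None, None)  # PyroJitGuideWarmup warms up the guide here
    for _ in range(n_epochs):
        cb.on_train_epoch_start(None, None)
        for batch_idx in range(n_batches):
            cb.on_train_batch_start(None, None, None, batch_idx)
            cb.on_train_batch_end(None, None, None, None, batch_idx)
        cb.on_train_epoch_end(None, None)
    cb.on_train_end(None, None)
    cb.on_fit_end(None, None)
    cb.teardown(None, None, "fit")


rec = HookRecorder()
toy_fit(rec)
print(rec.events.index("on_train_start"), rec.events.index("on_train_batch_start"))  # 4 6
```

In this picture, the warmup runs exactly once per fit, before the first training batch.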

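PyroJitGuideWarmup.on_load_checkpoint(trainer, pl_module, checkpoint)

Called when loading a model checkpoint, use to reload state.

Parameters:

trainer (Trainer) – the current Trainer instance.

pl_module (LightningModule) – the current LightningModule instance.

checkpoint (Dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.

Return type:

None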

PyroJitGuideWarmup.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the predict batch ends.

Return type:

None

PyroJitGuideWarmup.on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the predict batch begins.

Return type:

None

PyroJitGuideWarmup.on_predict_end(trainer, pl_module)

Called when predict ends.

Return type:

None

PyroJitGuideWarmup.on_predict_epoch_end(trainer, pl_module)

Called when the predict epoch ends.

Return type:

None

PyroJitGuideWarmup.on_predict_epoch_start(trainer, pl_module)

Called when the predict epoch begins.

Return type:

None

PyroJitGuideWarmup.on_predict_start(trainer, pl_module)

Called when predict begins.

Return type:

None

PyroJitGuideWarmup.on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

Return type:

None

PyroJitGuideWarmup.on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

Return type:

None

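PyroJitGuideWarmup.on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

trainer (Trainer) – the current Trainer instance.

pl_module (LightningModule) – the current LightningModule instance.

checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.

Return type:

None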
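The two checkpoint hooks both work on the full checkpoint dictionary. A callback can therefore write an extra entry on save and read it back on load. The sketch below uses plain dicts and a JSON round trip in place of a real checkpoint file. `ExtraInfoCallback` and the `my_extra_info` key are hypothetical.

```python
import json


class ExtraInfoCallback:
    """Toy callback that stores extra data in the full checkpoint dictionary."""

    def __init__(self, epochs_seen=0):
        self.epochs_seen = epochs_seen
        self.restored = None

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # "my_extra_info" is a hypothetical key chosen for this example.
        checkpoint["my_extra_info"] = {"epochs_seen": self.epochs_seen}

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        self.restored = checkpoint.get("my_extra_info")


checkpoint = {"state_dict": {}}  # minimal stand-in for a real checkpoint
ExtraInfoCallback(epochs_seen=3).on_save_checkpoint(None, None, checkpoint)

# Pretend the checkpoint went to disk and back.
loaded = json.loads(json.dumps(checkpoint))

restorer = ExtraInfoCallback()
restorer.on_load_checkpoint(None, None, loaded)
print(restorer.restored)  # {'epochs_seen': 3}
```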

PyroJitGuideWarmup.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the test batch ends.

Return type:

None

PyroJitGuideWarmup.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the test batch begins.

Return type:

None

PyroJitGuideWarmup.on_test_end(trainer, pl_module)

Called when the test ends.

Return type:

None

PyroJitGuideWarmup.on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

Return type:

None

PyroJitGuideWarmup.on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

Return type:

None

PyroJitGuideWarmup.on_test_start(trainer, pl_module)

Called when the test begins.

Return type:

None

PyroJitGuideWarmup.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the train batch ends.

Note

The value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

Return type:

None
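As we understand it, normalizing here means dividing by accumulate_grad_batches. With accumulate_grad_batches=4, a training_step loss of 2.0 would show up as 0.5 in outputs["loss"]. With the default of 1, the loss is unchanged. The sketch below shows the arithmetic, using plain floats in place of tensors and hypothetical helper names:

```python
def normalized_loss(raw_loss, accumulate_grad_batches):
    """Value outputs["loss"] would hold: the loss divided by the factor."""
    return raw_loss / accumulate_grad_batches


def raw_loss_from_outputs(outputs, accumulate_grad_batches):
    """Undo the normalization to recover the loss returned by training_step."""
    return outputs["loss"] * accumulate_grad_batches


outputs = {"loss": normalized_loss(2.0, accumulate_grad_batches=4)}
print(outputs["loss"])  # 0.5
print(raw_loss_from_outputs(outputs, accumulate_grad_batches=4))  # 2.0
```

A callback that logs per-batch losses in this hook may want to multiply back when gradient accumulation is enabled.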

PyroJitGuideWarmup.on_train_batch_start(trainer, pl_module, batch, batch_idx)

Called when the train batch begins.

Return type:

None

PyroJitGuideWarmup.on_train_end(trainer, pl_module)

Called when training ends.

Return type:

None

PyroJitGuideWarmup.on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Cache of per-step losses, filled in training_step
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss tensor for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type:

None

PyroJitGuideWarmup.on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

Return type:

None

PyroJitGuideWarmup.on_train_start(trainer, pl_module)

Warm up the Pyro guide automatically when training starts.

The warmup is device agnostic.
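The hypothetical sketch below pictures the warmup. The hook has four steps:

1. Pick a dataloader.
2. Move one minibatch to the module's device.
3. Call the guide once.
4. Stop.

`FakeTensor`, `WarmupSketch`, and the toy `trainer` and `pl_module` namespaces (with their `train_dataloader`, `device` and `guide` attributes) are stand-ins, not the scvi-tools or Lightning API. Falling back to the training data when `dataloader=None` is an assumption about the default, not something this page states.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a tensor that only tracks the device it lives on."""

    values: tuple
    device: str = "cpu"

    def to(self, device):
        return FakeTensor(self.values, device)


class WarmupSketch:
    """Hypothetical re-creation of the warmup idea, not the scvi-tools source."""

    def __init__(self, dataloader=None):
        self.dataloader = dataloader

    def on_train_start(self, trainer, pl_module):
        # Assumption: with dataloader=None, use the trainer's training data.
        dl = self.dataloader if self.dataloader is not None else trainer.train_dataloader
        for batch in dl:
            # Device agnostic: move every tensor to wherever the module lives.
            batch = {key: t.to(pl_module.device) for key, t in batch.items()}
            pl_module.guide(batch)
            break  # one minibatch is enough to initialize the guide


seen = []


def guide(batch):
    seen.append({key: t.device for key, t in batch.items()})


batches = [{"X": FakeTensor((1.0, 2.0))}, {"X": FakeTensor((3.0, 4.0))}]
trainer = SimpleNamespace(train_dataloader=batches)       # toy trainer
pl_module = SimpleNamespace(device="cuda:0", guide=guide)  # toy module
WarmupSketch().on_train_start(trainer, pl_module)
print(seen)  # [{'X': 'cuda:0'}]
```

Passing an explicit `dataloader` would let the warmup use a small, dedicated loader instead of the training data.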

PyroJitGuideWarmup.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the validation batch ends.

Return type:

None

PyroJitGuideWarmup.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the validation batch begins.

Return type:

None
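When several dataloaders are configured, the dataloader_idx argument identifies which one a batch came from. Its default of 0 covers the single-dataloader case. The stdlib-only sketch below uses a hypothetical loop that counts batch_idx separately within each dataloader. As we understand it, Lightning's evaluation loop does the same. The same pattern applies to the test and predict batch hooks.

```python
class BatchIndexRecorder:
    """Toy callback that records (dataloader_idx, batch_idx) pairs."""

    def __init__(self):
        self.seen = []

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        self.seen.append((dataloader_idx, batch_idx))


def toy_validation_loop(callback, dataloaders):
    """Hypothetical loop: batch_idx restarts for each dataloader."""
    for dataloader_idx, dataloader in enumerate(dataloaders):
        for batch_idx, batch in enumerate(dataloader):
            callback.on_validation_batch_start(None, None, batch, batch_idx, dataloader_idx)


rec = BatchIndexRecorder()
toy_validation_loop(rec, [["a", "b"], ["c"]])
print(rec.seen)  # [(0, 0), (0, 1), (1, 0)]
```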

PyroJitGuideWarmup.on_validation_end(trainer, pl_module)

Called when the validation loop ends.

Return type:

None

PyroJitGuideWarmup.on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

Return type:

None

PyroJitGuideWarmup.on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

Return type:

None

PyroJitGuideWarmup.on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type:

None

PyroJitGuideWarmup.setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

Return type:

None

PyroJitGuideWarmup.state_dict()

Called when saving a checkpoint, implement to generate callback’s state_dict.

Return type:

Dict[str, Any]

Returns:

A dictionary containing callback state.
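The state_dict / load_state_dict pair is meant to round-trip. Whatever state_dict returns is later handed back to load_state_dict from checkpoint["callbacks"][state_key]. The stdlib-only sketch below shows that round trip with a toy callback. The helpers `save_checkpoint` and `restore_checkpoint` are hypothetical and stand in for the Trainer.

```python
class StatefulCounter:
    """Toy callback implementing the state_dict / load_state_dict pair."""

    def __init__(self):
        self.batches_seen = 0

    @property
    def state_key(self):
        return type(self).__qualname__

    def state_dict(self):
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        self.batches_seen = state_dict["batches_seen"]


def save_checkpoint(callbacks):
    """Collect each callback's state under checkpoint["callbacks"][state_key]."""
    return {"callbacks": {cb.state_key: cb.state_dict() for cb in callbacks}}


def restore_checkpoint(checkpoint, callbacks):
    """Hand each callback the state saved under its own key, if any."""
    for cb in callbacks:
        state = checkpoint["callbacks"].get(cb.state_key)
        if state is not None:
            cb.load_state_dict(state)


counter = StatefulCounter()
counter.batches_seen = 7
ckpt = save_checkpoint([counter])

fresh = StatefulCounter()
restore_checkpoint(ckpt, [fresh])
print(fresh.batches_seen)  # 7
```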

PyroJitGuideWarmup.teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

Return type:

None
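setup and teardown both receive the stage name. A callback can therefore limit work to one entry point and release it symmetrically. The sketch below is a toy, stdlib-only callback that acquires a hypothetical resource only when stage is "fit". Other stage strings, such as "test", simply pass through.

```python
class FitOnlyResource:
    """Toy callback that holds a resource only while the "fit" stage runs."""

    def __init__(self):
        self.resource = None
        self.log = []

    def setup(self, trainer, pl_module, stage):
        self.log.append(f"setup:{stage}")
        if stage == "fit":
            self.resource = {"buffer": []}  # hypothetical per-fit resource

    def teardown(self, trainer, pl_module, stage):
        self.log.append(f"teardown:{stage}")
        if stage == "fit":
            self.resource = None  # release what setup acquired


cb = FitOnlyResource()
cb.setup(None, None, "fit")
print(cb.resource)  # {'buffer': []}
cb.teardown(None, None, "fit")
cb.setup(None, None, "test")
print(cb.resource, cb.log)  # None ['setup:fit', 'teardown:fit', 'setup:test']
```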