class scvi.train.SaveCheckpoint(dirpath=None, filename=None, monitor='validation_loss', **kwargs)

Bases: ModelCheckpoint

EXPERIMENTAL: Saves model checkpoints based on a monitored metric.

Inherits from ModelCheckpoint and modifies the default behavior to save the full model state instead of just the state dict. This is necessary for compatibility with BaseModelClass.

The path to the best saved model and its score, as measured by monitor, can be accessed after training through the best_model_path and best_model_score attributes, respectively.
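The bookkeeping behind best_model_path and best_model_score can be pictured with the minimal sketch below. It assumes that a lower monitored value is better, as for validation_loss, and uses a hypothetical BestCheckpointTracker class; it is not the Lightning or scvi-tools implementation, which additionally supports a configurable mode and save_top_k.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BestCheckpointTracker:
    """Hypothetical stand-in for best-checkpoint bookkeeping.

    Assumes lower monitored values are better (mode="min"), as for validation_loss.
    """

    best_model_path: Optional[str] = None
    best_model_score: Optional[float] = None

    def update(self, path: str, score: float) -> bool:
        """Record a saved checkpoint; return True if it is the new best."""
        if self.best_model_score is None or score < self.best_model_score:
            self.best_model_path = path
            self.best_model_score = score
            return True
        return False


# Simulated per-epoch validation losses (illustrative numbers only).
tracker = BestCheckpointTracker()
for epoch, loss in enumerate([310.2, 295.7, 298.1]):
    tracker.update(f"epoch={epoch}-validation_loss={loss}", loss)

print(tracker.best_model_path)   # epoch=1-validation_loss=295.7
print(tracker.best_model_score)  # 295.7
```

Epoch 2 is saved but does not replace the best entry, because its loss is higher than the one recorded at epoch 1.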

Known issues:

  • Does not set train_indices, validation_indices, and test_indices for checkpoints.

  • Does not set history for checkpoints. History can, however, be accessed in the final model.

  • Unsupported arguments: save_weights_only and save_last.
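Because save_weights_only and save_last are listed as unsupported, one option is to check keyword arguments before constructing the callback. The helper below, check_checkpoint_kwargs, is hypothetical and not part of scvi-tools; it rejects only truthy values, on the assumption that passing the ModelCheckpoint defaults (False or None) explicitly changes nothing.

```python
from typing import Any, Dict

# Arguments that the SaveCheckpoint documentation lists as unsupported.
UNSUPPORTED_KWARGS = ("save_weights_only", "save_last")


def check_checkpoint_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical guard: raise if an unsupported argument is switched on.

    Falsy values (False, None) are let through on the assumption that they
    match ModelCheckpoint's defaults.
    """
    bad = [name for name in UNSUPPORTED_KWARGS if kwargs.get(name)]
    if bad:
        raise ValueError(f"SaveCheckpoint does not support: {', '.join(bad)}")
    return kwargs


print(check_checkpoint_kwargs({"save_top_k": 3}))  # {'save_top_k': 3}
```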

Parameters:

  • dirpath (str | None (default: None)) – Base directory in which to save the model checkpoints. If None, defaults to a directory inside settings.logging_dir whose name is formatted from the current date, time, and monitor.

  • filename (str | None (default: None)) – Name of the checkpoint directories. May contain formatting placeholders that are filled in automatically. If None, defaults to {epoch}-{step}-{monitor}.

  • monitor (str (default: 'validation_loss')) – Metric to monitor for checkpointing.

  • **kwargs – Additional keyword arguments passed into ModelCheckpoint.
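The dirpath and filename defaults described above can be approximated in plain Python as a sketch. The timestamp format (%Y-%m-%d_%H-%M-%S), the helper names, and the "logs" directory are assumptions for illustration only. fill_template approximates Lightning's default auto_insert_metric_name behavior, which rewrites each {name} placeholder as name={name} before filling it in.

```python
import os
import re
from datetime import datetime


def default_checkpoint_dir(logging_dir: str, monitor: str, now: datetime) -> str:
    """Hypothetical: <logging_dir>/<date>_<time>_<monitor> (timestamp format assumed)."""
    stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    return os.path.join(logging_dir, f"{stamp}_{monitor}")


def default_filename_template(monitor: str) -> str:
    """The documented default {epoch}-{step}-{monitor}, with monitor's name as a placeholder."""
    return "{epoch}-{step}-{" + monitor + "}"


def fill_template(template: str, metrics: dict) -> str:
    """Approximates Lightning's auto_insert_metric_name: "{name}" -> "name={name}"."""
    # Match placeholders such as {epoch} or {validation_loss:.2f}.
    labeled = re.sub(
        r"\{([^{}:]+)(:[^{}]*)?\}",
        lambda m: f"{m.group(1)}={m.group(0)}",
        template,
    )
    return labeled.format(**metrics)


template = default_filename_template("validation_loss")
print(template)  # {epoch}-{step}-{validation_loss}
print(fill_template(template, {"epoch": 4, "step": 250, "validation_loss": 301.25}))
# epoch=4-step=250-validation_loss=301.25
print(default_checkpoint_dir("logs", "validation_loss", datetime(2024, 5, 1, 13, 45, 9)))
# logs/2024-05-01_13-45-09_validation_loss (on POSIX)
```

Under these assumptions, a checkpoint saved at epoch 4, step 250 with the default template would be written to a directory named along the lines of epoch=4-step=250-validation_loss=301.25.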

Methods table

on_save_checkpoint(trainer, *args)

Saves the model state on Lightning checkpoint saves.



SaveCheckpoint.on_save_checkpoint(trainer, *args)

Saves the model state on Lightning checkpoint saves.

Return type: