scvi.train.SaveCheckpoint#

class scvi.train.SaveCheckpoint(dirpath=None, filename=None, monitor='validation_loss', load_best_on_end=False, check_nan_gradients=False, **kwargs)[source]#

Bases: ModelCheckpoint

BETA Saves model checkpoints based on a monitored metric.

Inherits from ModelCheckpoint and modifies the default behavior to save the full model state instead of just the state dict. This is necessary for compatibility with BaseModelClass.

The path to the best checkpoint and the best score for monitor can be accessed after training via the best_model_path and best_model_score attributes, respectively.

Known issues:

  • Does not set train_indices, validation_indices, or test_indices for checkpoints.

  • Does not set history for checkpoints; it can, however, be accessed on the final model.

  • Unsupported arguments: save_weights_only and save_last.

Parameters:
  • dirpath (str | None (default: None)) – Base directory to save the model checkpoints. If None, defaults to a subdirectory in scvi.settings.logging_dir formatted with the current date, time, and monitor.

  • filename (str | None (default: None)) – Name for the checkpoint directories, which can contain formatting options for auto-filling. If None, defaults to {epoch}-{step}-{monitor}.

  • monitor (str (default: 'validation_loss')) – Metric to monitor for checkpointing.

  • load_best_on_end (bool (default: False)) – If True, loads the best model state into the model at the end of training.

  • check_nan_gradients (bool (default: False)) – If True, uses the on_exception callback to save the best model state when training fails due to NaN values in the gradients or loss.

  • **kwargs – Additional keyword arguments passed into the constructor for ModelCheckpoint.
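As a usage sketch (assuming scvi-tools is installed; the synthetic dataset and the forwarding of callbacks through model.train's trainer keyword arguments are illustrative assumptions, not prescribed by this page):

```python
import scvi

# Placeholder workflow: set up a model on scvi's built-in synthetic dataset.
adata = scvi.data.synthetic_iid()
scvi.model.SCVI.setup_anndata(adata)
model = scvi.model.SCVI(adata)

# Monitor the validation loss, keep the best checkpoint, restore it at the
# end of training, and guard against NaN gradients/losses.
checkpoint = scvi.train.SaveCheckpoint(
    monitor="validation_loss",
    load_best_on_end=True,
    check_nan_gradients=True,
)

# `callbacks` is forwarded to the underlying Lightning Trainer.
model.train(max_epochs=10, callbacks=[checkpoint])

# Inspect the best checkpoint found during training.
print(checkpoint.best_model_path)
print(checkpoint.best_model_score)
```

Because load_best_on_end=True, the model object holds the best-scoring state after train returns, so no manual reload from best_model_path is needed.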

Attributes table#

Methods table#

on_exception(trainer, pl_module, exception)

Save the model in case of unexpected exceptions, such as NaN values in the loss or gradients.

on_save_checkpoint(trainer, *args)

Saves the model state on Lightning checkpoint saves.

on_train_batch_end(trainer, pl_module, ...)

Save a checkpoint at the end of a training batch if the every_n_train_steps criterion is met.

on_train_end(trainer, pl_module)

Loads the best model state into the model at the end of training.

Attributes#

Methods#

SaveCheckpoint.on_exception(trainer, pl_module, exception)[source]#

Save the model in case of unexpected exceptions, such as NaN values in the loss or gradients.

Return type:

None

SaveCheckpoint.on_save_checkpoint(trainer, *args)[source]#

Saves the model state on Lightning checkpoint saves.

Return type:

None

SaveCheckpoint.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]#

Save a checkpoint at the end of a training batch if the every_n_train_steps criterion is met.

Return type:

None

SaveCheckpoint.on_train_end(trainer, pl_module)[source]#

Loads the best model state into the model at the end of training.

Return type:

None