scvi.train.SaveCheckpoint
- class scvi.train.SaveCheckpoint(dirpath=None, filename=None, monitor='validation_loss', load_best_on_end=False, check_nan_gradients=False, **kwargs)
Bases: ModelCheckpoint
Saves model checkpoints based on a monitored metric.
Inherits from ModelCheckpoint and modifies the default behavior to save the full model state instead of just the state dict. This is necessary for compatibility with BaseModelClass.
The best model save and best model score based on monitor can be accessed post-training with the best_model_path and best_model_score attributes, respectively.
Known issues:
- Does not set train_indices, validation_indices, and test_indices for checkpoints.
- Does not set history for checkpoints. This can be accessed in the final model, however.
Unsupported arguments: save_weights_only and save_last.
- Parameters:
dirpath (str | None (default: None)) – Base directory to save the model checkpoints. If None, defaults to a subdirectory in scvi.settings.logging_dir formatted with the current date, time, and monitor.
filename (str | None (default: None)) – Name for the checkpoint directories, which can contain formatting options for auto filling. If None, defaults to {epoch}-{step}-{monitor}.
monitor (str (default: 'validation_loss')) – Metric to monitor for checkpointing.
load_best_on_end (bool (default: False)) – If True, loads the best model state into the model at the end of training.
check_nan_gradients (bool (default: False)) – If True, uses the on_exception callback to store the best model if training raises an exception caused by NaNs in gradients or loss calculations.
**kwargs – Additional keyword arguments passed into the constructor for ModelCheckpoint.
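The defaults described for dirpath and filename can be illustrated with a small standard-library sketch. It does not reproduce the scvi-tools or Lightning source: the helper names (default_checkpoint_dir, default_filename_template, expand_template), the timestamp layout, and the name=value expansion of template fields are all assumptions made for illustration.

```python
import os
import re
from datetime import datetime


def default_checkpoint_dir(logging_dir: str, monitor: str, now: datetime) -> str:
    """Hypothetical helper: a dated subdirectory of logging_dir tagged with the monitor."""
    # Assumed timestamp layout; the format used by the library may differ.
    stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    return os.path.join(logging_dir, f"{stamp}_{monitor}")


def default_filename_template(monitor: str) -> str:
    """Hypothetical helper: put the monitored metric name into the default template."""
    return "{epoch}-{step}-{" + monitor + "}"


def expand_template(template: str, metrics: dict) -> str:
    """Hypothetical helper: expand every {name} field as name=value."""
    # Rewrite "{name" to "name={name" so that str.format emits name=value pairs.
    labelled = re.sub(r"\{([^:}]+)", r"\1={\1", template)
    return labelled.format(**metrics)


template = default_filename_template("validation_loss")
example_name = expand_template(
    template, {"epoch": 3, "step": 120, "validation_loss": 0.25}
)
print(template)      # {epoch}-{step}-{validation_loss}
print(example_name)  # epoch=3-step=120-validation_loss=0.25
```

Putting the metric name in both the directory and the checkpoint name makes it easier to tell apart runs that monitor different quantities. Check the checkpoint names your installed version actually writes before relying on this layout.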
Attributes table
Methods table
| on_exception(trainer, pl_module, exception) | Save the model in case of unexpected exceptions, like NaN in loss or gradients. |
| on_save_checkpoint(trainer, *args) | Saves the model state on Lightning checkpoint saves. |
| | Save checkpoint on train batch end if we meet the criteria for every_n_train_steps. |
| | Loads the best model state into the model at the end of training. |
Attributes
Methods
- SaveCheckpoint.on_exception(trainer, pl_module, exception)
Save the model in case of unexpected exceptions, such as NaN values in the loss or gradients.
- Return type:
- SaveCheckpoint.on_save_checkpoint(trainer, *args)
Saves the model state on Lightning checkpoint saves.
- Return type: