scvi.train.SaveCheckpoint
- class scvi.train.SaveCheckpoint(dirpath=None, filename=None, monitor='validation_loss', load_best_on_end=False, **kwargs)
Bases: ModelCheckpoint
BETA
Saves model checkpoints based on a monitored metric. Inherits from ModelCheckpoint and modifies the default behavior to save the full model state instead of just the state dict, which is necessary for compatibility with BaseModelClass.
After training, the path and score of the best checkpoint, as ranked by monitor, can be accessed through the best_model_path and best_model_score attributes, respectively.
Known issues:
- Does not set train_indices, validation_indices, and test_indices for checkpoints.
- Does not set history for checkpoints. The history can, however, be accessed in the final model.
- Unsupported arguments: save_weights_only and save_last.
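The bookkeeping behind best_model_path and best_model_score can be pictured with the minimal, self-contained sketch below. It does not use scvi-tools or Lightning: BestCheckpointTracker and its update method are hypothetical names, the checkpoint names and losses are made up, and it assumes lower values of the monitored metric are better (the default mode="min" of Lightning's ModelCheckpoint).

```python
import math
from dataclasses import dataclass


@dataclass
class BestCheckpointTracker:
    """Hypothetical stand-in for the best-checkpoint bookkeeping of a checkpoint callback.

    Assumes mode="min": a lower monitored value (e.g. a loss) is better.
    """

    monitor: str = "validation_loss"
    best_model_path: str = ""
    best_model_score: float = math.inf

    def update(self, path: str, metrics: dict) -> bool:
        """Record `path` as the best checkpoint if its monitored score improves."""
        score = metrics[self.monitor]
        if score < self.best_model_score:
            self.best_model_path = path
            self.best_model_score = score
            return True
        return False


# Illustrative (made-up) checkpoint names and validation losses.
tracker = BestCheckpointTracker()
for path, loss in [("ckpt_epoch0", 0.90), ("ckpt_epoch1", 0.75), ("ckpt_epoch2", 0.80)]:
    tracker.update(path, {"validation_loss": loss})

print(tracker.best_model_path)   # ckpt_epoch1
print(tracker.best_model_score)  # 0.75
```

Only an improvement replaces the stored path, so a later, worse epoch (ckpt_epoch2) does not overwrite the best record.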
- Parameters:
  - dirpath (str | None (default: None)) – Base directory to save the model checkpoints. If None, defaults to a subdirectory in scvi.settings.logging_dir formatted with the current date, time, and monitor.
  - filename (str | None (default: None)) – Name for the checkpoint directories, which can contain formatting options for auto-filling. If None, defaults to {epoch}-{step}-{monitor}.
  - monitor (str (default: 'validation_loss')) – Metric to monitor for checkpointing.
  - load_best_on_end (bool (default: False)) – If True, loads the best model state into the model at the end of training.
  - **kwargs – Additional keyword arguments passed into the constructor for ModelCheckpoint.
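The defaults for dirpath and filename can be illustrated with the sketch below. The helpers default_dirpath and default_filename are hypothetical, and the timestamp format "%Y-%m-%d_%H-%M-%S" is an assumption rather than the exact format scvi-tools uses. The filename helper reads {monitor} as "a placeholder named after the monitored metric", so that its value is auto-filled.

```python
import os
from datetime import datetime


def default_dirpath(logging_dir: str, monitor: str, now: datetime) -> str:
    """Hypothetical default directory: logging_dir/<date>_<time>_<monitor>.

    The timestamp format is an assumption, not necessarily the one scvi-tools uses.
    """
    stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    return os.path.join(logging_dir, f"{stamp}_{monitor}")


def default_filename(monitor: str) -> str:
    """Build the documented fallback template {epoch}-{step}-{<monitor>}."""
    return "{epoch}-{step}-{" + monitor + "}"


template = default_filename("validation_loss")
print(template)  # {epoch}-{step}-{validation_loss}

# Plain str.format only shows which values get auto-filled; Lightning's own
# formatter may also insert metric names (e.g. "epoch=3") and format numbers differently.
print(template.format(epoch=3, step=120, validation_loss=0.42))  # 3-120-0.42

print(default_dirpath("scvi_logs", "validation_loss", datetime(2024, 1, 2, 3, 4, 5)))
```

Stamping the directory with date, time, and monitor keeps separate training runs from writing into the same folder.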
Methods table
- Saves the model state on Lightning checkpoint saves.
- Loads the best model state into the model at the end of training.
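The two methods listed in the methods table can be made concrete with a toy sketch. It is not scvi-tools code. Each "checkpoint save" writes the full state of a hypothetical ToyModel (all of its attributes, pickled, rather than only weights) into its own directory. The end-of-training step then restores the lowest-loss checkpoint, as load_best_on_end=True is described to do. ToyModel, save, load_state, the directory layout, and the loss values are illustrative assumptions; the real callback saves full model state for compatibility with BaseModelClass.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical model whose full state (all attributes, not only weights) is pickled."""

    def __init__(self, weight: float):
        self.weight = weight

    def save(self, dirpath: str) -> None:
        os.makedirs(dirpath, exist_ok=True)
        with open(os.path.join(dirpath, "model.pkl"), "wb") as f:
            pickle.dump(self.__dict__, f)

    def load_state(self, dirpath: str) -> None:
        with open(os.path.join(dirpath, "model.pkl"), "rb") as f:
            self.__dict__.update(pickle.load(f))


with tempfile.TemporaryDirectory() as tmp:
    model = ToyModel(weight=0.0)
    scores = {}  # checkpoint directory -> monitored validation loss (made-up values)
    # Step 1 (on each checkpoint save): write the full model state to its own directory.
    for epoch, (weight, loss) in enumerate([(1.0, 0.9), (2.0, 0.6), (3.0, 0.7)]):
        model.weight = weight  # pretend an epoch of training changed the model
        path = os.path.join(tmp, f"epoch={epoch}")
        model.save(path)
        scores[path] = loss
    # Step 2 (at the end of training, if load_best_on_end=True): restore the best state.
    best_model_path = min(scores, key=scores.get)
    model.load_state(best_model_path)
    print(model.weight)  # 2.0
```

Without the restore step, the model would keep the state from the last epoch (weight 3.0) even though an earlier checkpoint scored better.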