scvi.train.Trainer
- class scvi.train.Trainer(accelerator=None, devices=None, benchmark=True, check_val_every_n_epoch=None, max_epochs=400, default_root_dir=None, enable_checkpointing=False, checkpointing_monitor='validation_loss', num_sanity_val_steps=0, enable_model_summary=False, early_stopping=False, early_stopping_monitor='elbo_validation', early_stopping_min_delta=0.0, early_stopping_patience=45, early_stopping_mode='min', enable_progress_bar=True, progress_bar_refresh_rate=1, simple_progress_bar=True, logger=None, log_every_n_steps=10, learning_rate_monitor=False, **kwargs)
Bases: Trainer
Lightweight wrapper of the PyTorch Lightning Trainer.
It sets defaults appropriate for scvi-tools models and configures callbacks such as EarlyStopping through the parameters below.
- Parameters:
accelerator (str | Accelerator | None (default: None)) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.
devices (list[int] | str | int | None (default: None)) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate that all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".
benchmark (bool (default: True)) – If True, enables cudnn.benchmark, which improves speed when inputs have a fixed size.
check_val_every_n_epoch (int | None (default: None)) – Run validation every n training epochs. By default, validation is not run unless early_stopping is True.
max_epochs (int (default: 400)) – Stop training once this number of epochs is reached.
default_root_dir (str | None (default: None)) – Default path for logs and weights when no logger or checkpoint callback is passed. Defaults to scvi.settings.logging_dir. Can be a remote file path such as s3://mybucket/path or hdfs://path/.
enable_checkpointing (bool (default: False)) – If True, enables checkpointing with a default SaveCheckpoint callback if there is no user-defined SaveCheckpoint in callbacks.
checkpointing_monitor (str (default: 'validation_loss')) – If enable_checkpointing is True, specifies the metric to monitor for checkpointing.
num_sanity_val_steps (int (default: 0)) – Number of validation batches to run as a sanity check before starting the training routine. Set it to -1 to run all batches in all validation dataloaders.
enable_model_summary (bool (default: False)) – Whether to enable or disable the model summary.
early_stopping (bool (default: False)) – Whether to perform early stopping with respect to the validation set. This automatically adds an EarlyStopping instance. To use a custom instance instead, pass it through the callbacks argument and set this to False.
early_stopping_monitor (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')) – Metric logged during the validation epoch. The available metrics depend on the training plan class used; the most common options are listed in the type annotation.
early_stopping_min_delta (float (default: 0.0)) – Minimum change in the monitored quantity to qualify as an improvement; that is, an absolute change of less than min_delta counts as no improvement.
early_stopping_patience (int (default: 45)) – Number of validation epochs with no improvement after which training is stopped.
early_stopping_mode (Literal['min', 'max'] (default: 'min')) – In ‘min’ mode, training stops when the monitored quantity has stopped decreasing; in ‘max’ mode, it stops when the monitored quantity has stopped increasing.
enable_progress_bar (bool (default: True)) – Whether to enable or disable the progress bar.
progress_bar_refresh_rate (int (default: 1)) – How often to refresh the progress bar (in steps). A value of 0 disables the progress bar.
simple_progress_bar (bool (default: True)) – Use the custom scvi-tools simple progress bar (updated per epoch rather than per batch). When False, the default PyTorch Lightning progress bar is used, unless enable_progress_bar is False.
logger (Logger | None | bool (default: None)) – A valid PyTorch Lightning logger. If None, a simple dictionary logger is used. If True, the default PyTorch Lightning logger is used.
log_every_n_steps (int (default: 10)) – How often, in training steps, to log. This does not affect epoch-level logging.
**kwargs – Other keyword arguments passed to Trainer.
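The forms accepted by devices can be illustrated with a small normalizer. This is a sketch only: normalize_devices is a hypothetical name, not part of scvi-tools or PyTorch Lightning, and its string handling is modeled loosely on how Lightning parses device strings (a comma-separated string is read as indices, a bare number as a count). Treating None like "auto" is an assumption based on the documented default.

```python
from typing import List, Union

DeviceSpec = Union[List[int], str, int, None]


def normalize_devices(devices: DeviceSpec) -> Union[str, int, List[int]]:
    """Hypothetical helper: map a ``devices`` value to one canonical form.

    Returns "auto", -1 (all devices), a positive device count, or an
    explicit list of device indices. Not part of scvi-tools or Lightning.
    """
    if devices is None or devices == "auto":
        # Assumption: None is treated like "auto", per the documented default.
        return "auto"
    if isinstance(devices, str):
        text = devices.strip()
        if text == "-1":
            return -1
        if "," in text:
            # A comma-separated string is read as explicit device indices.
            return [int(part) for part in text.split(",") if part.strip()]
        devices = int(text)  # a bare numeric string is read as a device count
    if isinstance(devices, int):
        if devices == -1 or devices > 0:
            return devices
        raise ValueError(f"devices must be a positive number or -1, got {devices}")
    if isinstance(devices, list) and all(isinstance(i, int) for i in devices):
        return list(devices)
    raise TypeError(f"unsupported devices value: {devices!r}")
```

Under these assumptions, "2" and 2 both mean "use two devices", while "0,1" and [0, 1] both select the devices with indices 0 and 1.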
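check_val_every_n_epoch and early_stopping together decide after which training epochs validation runs. The sketch below encodes the documented rule (no validation by default unless early stopping is on). It adds one assumption that the text above does not state: with early stopping enabled and no explicit interval, validation runs after every epoch. validation_epochs is a hypothetical helper, not an scvi-tools function.

```python
from typing import List, Optional


def validation_epochs(
    max_epochs: int,
    check_val_every_n_epoch: Optional[int] = None,
    early_stopping: bool = False,
) -> List[int]:
    """Hypothetical helper: 1-based training epochs after which validation runs."""
    if check_val_every_n_epoch is None:
        if not early_stopping:
            # Documented default: validation is not run unless early stopping is on.
            return []
        # Assumption: early stopping without an explicit interval validates every epoch.
        check_val_every_n_epoch = 1
    # Validate after epochs n, 2n, 3n, ... up to max_epochs.
    return list(range(check_val_every_n_epoch, max_epochs + 1, check_val_every_n_epoch))
```

Because early_stopping_patience counts validation epochs rather than training epochs, the interval scales the effective patience: with the default patience of 45 and check_val_every_n_epoch=3, training tolerates roughly 45 × 3 = 135 training epochs without improvement.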
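The interplay of early_stopping_min_delta, early_stopping_patience, and early_stopping_mode can be simulated on a sequence of validation metric values. This is a simplified sketch of the documented semantics, modeled loosely on PyTorch Lightning's EarlyStopping callback, which compares each value strictly against the best value so far shifted by min_delta. early_stop_epoch is a hypothetical name; this is not the callback scvi-tools actually installs.

```python
from typing import Optional, Sequence


def early_stop_epoch(
    values: Sequence[float],
    min_delta: float = 0.0,
    patience: int = 45,
    mode: str = "min",
) -> Optional[int]:
    """Hypothetical helper: index of the validation epoch at which training stops.

    Returns None if the patience budget is never exhausted.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best: Optional[float] = None
    wait = 0  # validation epochs since the last improvement
    for i, value in enumerate(values):
        if best is None:
            improved = True  # the first value always sets the baseline
        elif mode == "min":
            # The metric must drop by more than min_delta to count as improving.
            improved = value < best - min_delta
        else:
            # In 'max' mode it must rise by more than min_delta.
            improved = value > best + min_delta
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None
```

For example, early_stop_epoch([10.0, 9.0, 8.95, 8.93, 8.5], min_delta=0.1, patience=2) returns 3. The values 8.95 and 8.93 improve on the best value 9.0 by less than min_delta, so two consecutive non-improving validation epochs exhaust the patience before the larger drop to 8.5 is seen.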
Attributes table
Methods table
| Fit the model. |