scvi.train.Trainer
- class scvi.train.Trainer(accelerator=None, devices=None, benchmark=True, check_val_every_n_epoch=None, max_epochs=400, default_root_dir=None, enable_checkpointing=False, num_sanity_val_steps=0, enable_model_summary=False, early_stopping=False, early_stopping_monitor='elbo_validation', early_stopping_min_delta=0.0, early_stopping_patience=45, early_stopping_mode='min', enable_progress_bar=True, progress_bar_refresh_rate=1, simple_progress_bar=True, logger=None, log_every_n_steps=10, replace_sampler_ddp=True, **kwargs)
Bases: TunableMixin, Trainer
Lightweight wrapper around the PyTorch Lightning Trainer.
It sets appropriate defaults for scvi-tools models and for callbacks such as EarlyStopping, whose parameters are exposed as arguments here.
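As a rough illustration of how a wrapper like this can expose callback settings as plain constructor arguments, the sketch below turns the early_stopping_* arguments into a callback configuration and resolves check_val_every_n_epoch. It is a hypothetical, dependency-free helper, not the scvi-tools implementation. EarlyStoppingConfig and build_trainer_kwargs are made-up names. Two behaviours are assumptions: validating every epoch when early stopping is on, and using sys.maxsize as a "never validate" interval.

```python
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class EarlyStoppingConfig:
    """Hypothetical stand-in for the settings handed to an EarlyStopping callback."""

    monitor: str
    min_delta: float
    patience: int
    mode: str


def build_trainer_kwargs(
    max_epochs: int = 400,
    check_val_every_n_epoch: Optional[int] = None,
    early_stopping: bool = False,
    early_stopping_monitor: str = "elbo_validation",
    early_stopping_min_delta: float = 0.0,
    early_stopping_patience: int = 45,
    early_stopping_mode: str = "min",
    **kwargs: Any,
) -> Dict[str, Any]:
    """Collect keyword arguments for an underlying trainer (hypothetical helper)."""
    # Copy any user-supplied callbacks so the caller's list is not mutated.
    callbacks: List[Any] = list(kwargs.pop("callbacks", None) or [])
    if early_stopping:
        callbacks.append(
            EarlyStoppingConfig(
                monitor=early_stopping_monitor,
                min_delta=early_stopping_min_delta,
                patience=early_stopping_patience,
                mode=early_stopping_mode,
            )
        )
        # Early stopping needs a validation metric; assume every epoch
        # unless the caller chose an interval explicitly.
        if check_val_every_n_epoch is None:
            check_val_every_n_epoch = 1
    elif check_val_every_n_epoch is None:
        # Assumed sentinel: an interval so large that validation never runs.
        check_val_every_n_epoch = sys.maxsize
    # Everything else (e.g. log_every_n_steps) is forwarded unchanged.
    return {
        "max_epochs": max_epochs,
        "check_val_every_n_epoch": check_val_every_n_epoch,
        "callbacks": callbacks,
        **kwargs,
    }
```

For example, build_trainer_kwargs(early_stopping=True, early_stopping_patience=10) yields check_val_every_n_epoch=1 and a single EarlyStoppingConfig with patience=10. With the defaults, validation is effectively disabled.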
- Parameters:
accelerator (Optional[Union[str, Accelerator]]) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.
devices (Optional[Union[List[int], str, int]]) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate that all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".
benchmark (bool) – If True, enables cudnn.benchmark, which can improve speed when input sizes are fixed.
check_val_every_n_epoch (Optional[int]) – Run validation every n training epochs. By default, validation is not run unless early_stopping is True.
max_epochs (Tunable_[int]) – Stop training once this number of epochs is reached.
default_root_dir (Optional[str]) – Default path for logs and weights when no logger or checkpoint callback is passed. Defaults to scvi.settings.logging_dir. Can be a remote file path such as s3://mybucket/path or hdfs://path/.
enable_checkpointing (bool) – If True, enables checkpointing. A default ModelCheckpoint callback is configured if there is no user-defined ModelCheckpoint in callbacks.
num_sanity_val_steps (int) – Number of validation batches to run as a sanity check before the training routine starts. Set it to -1 to run all batches in all validation dataloaders.
enable_model_summary (bool) – Whether to enable or disable the model summarization.
early_stopping (bool) – Whether to perform early stopping with respect to the validation set. This automatically adds an EarlyStopping instance. To use a custom instance instead, pass it through the callbacks argument and set this to False.
early_stopping_monitor (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']) – Metric, logged during the validation epoch, to monitor. The available metrics depend on the training plan class used; the most common options are listed in the type annotation.
early_stopping_min_delta (float) – Minimum change in the monitored quantity to qualify as an improvement; that is, an absolute change of less than min_delta counts as no improvement.
early_stopping_patience (int) – Number of validation epochs with no improvement after which training will be stopped.
early_stopping_mode (Literal['min', 'max']) – In ‘min’ mode, training will stop when the quantity monitored has stopped decreasing and in ‘max’ mode it will stop when the quantity monitored has stopped increasing.
enable_progress_bar (bool) – Whether to enable or disable the progress bar.
progress_bar_refresh_rate (int) – How often to refresh the progress bar (in steps). A value of 0 disables the progress bar.
simple_progress_bar (bool) – Use the custom scvi-tools simple progress bar (per epoch rather than per batch). When False, the default PyTorch Lightning progress bar is used, unless enable_progress_bar is False.
logger (Union[Logger, None, bool]) – A valid PyTorch Lightning logger. Defaults to a simple dictionary logger. If True, the default PyTorch Lightning logger is used.
log_every_n_steps (int) – How often, in training steps, to log. This does not affect epoch-level logging.
replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If True, by default it adds shuffle=True for the train sampler and shuffle=False for the val/test sampler. To customize this, set replace_sampler_ddp=False and add your own distributed sampler.
**kwargs – Other keyword arguments for Trainer.
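The forms accepted by devices can be made concrete with a small, hypothetical helper that maps each one to a list of device indices. It only models the cases spelled out in the parameter description above. resolve_devices and num_available are illustrative names. Treating an integer count n as "the first n devices" is an assumption. "auto" is rejected because its result depends on the accelerator.

```python
from typing import List, Sequence, Union


def resolve_devices(devices: Union[int, str, Sequence[int]], num_available: int) -> List[int]:
    """Map a `devices` value to concrete device indices (hypothetical helper)."""
    if isinstance(devices, str):
        text = devices.strip()
        if text == "auto":
            raise ValueError('"auto" depends on the accelerator and is not modeled here')
        if "," in text:
            # A comma-separated string is a sequence of indices, e.g. "0,2".
            devices = [int(part) for part in text.split(",") if part.strip()]
        else:
            # Otherwise the string is a device count (or "-1").
            devices = int(text)
    if isinstance(devices, int):
        if devices == -1:
            # -1 means "use every available device".
            return list(range(num_available))
        if not 0 < devices <= num_available:
            raise ValueError(f"cannot use {devices} of {num_available} devices")
        # Assumption: a count n selects the first n devices.
        return list(range(devices))
    # A list (or other sequence) is taken as explicit device indices.
    indices = list(devices)
    for index in indices:
        if not 0 <= index < num_available:
            raise ValueError(f"device index {index} is out of range")
    return indices
```

Note that "2" and "2," mean different things under this reading: the first is a count of two devices, while the second is a one-element index list selecting device 2.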
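The interaction of early_stopping_min_delta, early_stopping_patience and early_stopping_mode can be sketched as a pure function over a sequence of monitored validation values. This is a simplified model of the rule described in the parameter list, not PyTorch Lightning's EarlyStopping class. early_stopping_check is a hypothetical name, and treating a change of exactly min_delta as no improvement is an assumption.

```python
from typing import Iterable, Optional


def early_stopping_check(
    values: Iterable[float],
    min_delta: float = 0.0,
    patience: int = 45,
    mode: str = "min",
) -> Optional[int]:
    """Return the index of the validation check that triggers stopping, or None."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best: Optional[float] = None
    wait = 0  # validation checks since the last improvement
    for i, value in enumerate(values):
        if best is None:
            improved = True  # the first value always sets the baseline
        elif mode == "min":
            # Must decrease by more than min_delta to count as an improvement.
            improved = value < best - min_delta
        else:
            # Must increase by more than min_delta to count as an improvement.
            improved = value > best + min_delta
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None
```

For example, with the monitored values [10.0, 9.0, 8.95, 8.96, 8.97], min_delta=0.1 and patience=3 in "min" mode, only the drop from 10.0 to 9.0 counts as an improvement, so the function returns 4 (the fifth validation check). Because patience counts validation checks, the number of training epochs this corresponds to also depends on check_val_every_n_epoch.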
Attributes table
Methods table
fit – Fit the model.
Attributes
Methods
fit