class scvi.train.Trainer(accelerator=None, devices=None, benchmark=True, check_val_every_n_epoch=None, max_epochs=400, default_root_dir=None, enable_checkpointing=False, checkpointing_monitor='validation_loss', num_sanity_val_steps=0, enable_model_summary=False, early_stopping=False, early_stopping_monitor='elbo_validation', early_stopping_min_delta=0.0, early_stopping_patience=45, early_stopping_mode='min', additional_val_metrics=None, enable_progress_bar=True, progress_bar_refresh_rate=1, simple_progress_bar=True, logger=None, log_every_n_steps=10, learning_rate_monitor=False, **kwargs)

Bases: TunableMixin, Trainer

Lightweight wrapper of the PyTorch Lightning Trainer.

Appropriate defaults are set for scvi-tools models, as well as for callbacks such as EarlyStopping, whose parameters are exposed here.

Parameters:

  • accelerator (Union[str, Accelerator, None] (default: None)) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • devices (Union[list[int], str, int, None] (default: None)) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".

  • benchmark (bool (default: True)) – If True, enables cudnn.benchmark, which can improve speed when input sizes are fixed.

  • check_val_every_n_epoch (Optional[int] (default: None)) – Run validation every n training epochs. By default, validation is not run unless early_stopping is True.

  • max_epochs (Tunable_[int] (default: 400)) – Stop training once this number of epochs is reached.

  • default_root_dir (Optional[str] (default: None)) – Default path for logs and weights when no logger or checkpoint callback is passed. Defaults to scvi.settings.logging_dir. Can be a remote file path such as ‘s3://mybucket/path’ or ‘hdfs://path/’.

  • enable_checkpointing (bool (default: False)) – If True, enables checkpointing with a default SaveCheckpoint callback if there is no user-defined SaveCheckpoint in callbacks.

  • checkpointing_monitor (str (default: 'validation_loss')) – If enable_checkpointing is True, specifies the metric to monitor for checkpointing.

  • num_sanity_val_steps (int (default: 0)) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders.

  • enable_model_summary (bool (default: False)) – Whether to enable or disable the model summary.

  • early_stopping (bool (default: False)) – Whether to perform early stopping with respect to the validation set. This automatically adds an EarlyStopping instance. A custom instance can be passed through the callbacks argument while setting this to False.

  • early_stopping_monitor (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')) – Metric, logged during each validation epoch, that early stopping monitors. The available metrics depend on the training plan class used; the most common options are listed in the type annotation.

  • early_stopping_min_delta (float (default: 0.0)) – Minimum change in the monitored quantity to qualify as an improvement; that is, an absolute change of less than min_delta counts as no improvement.

  • early_stopping_patience (int (default: 45)) – Number of validation epochs with no improvement after which training will be stopped.

  • early_stopping_mode (Literal['min', 'max'] (default: 'min')) – In ‘min’ mode, training will stop when the quantity monitored has stopped decreasing and in ‘max’ mode it will stop when the quantity monitored has stopped increasing.

  • additional_val_metrics (Union[Callable[[BaseModelClass], float], list[Callable[[BaseModelClass], float]], dict[str, Callable[[BaseModelClass], float]], None] (default: None)) – Additional validation metrics to compute and log. See MetricsCallback for more details.

  • enable_progress_bar (bool (default: True)) – Whether to enable or disable the progress bar.

  • progress_bar_refresh_rate (int (default: 1)) – How often to refresh progress bar (in steps). Value 0 disables progress bar.

  • simple_progress_bar (bool (default: True)) – Use custom scvi-tools simple progress bar (per epoch rather than per batch). When False, uses default PyTorch Lightning progress bar, unless enable_progress_bar is False.

  • logger (Union[Logger, None, bool] (default: None)) – A valid PyTorch Lightning logger. Defaults to a simple dictionary logger. If True, uses the default PyTorch Lightning logger.

  • log_every_n_steps (int (default: 10)) – How often, in training steps, to log. This does not affect epoch-level logging.

  • **kwargs – Other keyword arguments passed to the PyTorch Lightning Trainer.
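The way early_stopping_min_delta, early_stopping_patience, and early_stopping_mode interact can be illustrated with a small plain-Python sketch. It assumes, as in PyTorch Lightning's EarlyStopping, that each validation value is compared against the best value seen so far and that a stop is triggered once `patience` consecutive validation checks fail to improve on it by more than `min_delta`. The class EarlyStoppingSketch is hypothetical and is not part of scvi-tools or PyTorch Lightning.

```python
import math
from dataclasses import dataclass, field


@dataclass
class EarlyStoppingSketch:
    """Hypothetical, simplified stopping rule mirroring the early_stopping_* parameters.

    Not the EarlyStopping callback that scvi-tools actually adds; illustration only.
    """

    min_delta: float = 0.0  # plays the role of early_stopping_min_delta
    patience: int = 45      # plays the role of early_stopping_patience
    mode: str = "min"       # plays the role of early_stopping_mode ("min" or "max")
    best: float = field(default=math.inf, init=False)
    wait_count: int = field(default=0, init=False)

    def __post_init__(self) -> None:
        if self.mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        # Start from the worst possible score so the first value always improves.
        self.best = math.inf if self.mode == "min" else -math.inf

    def _improved(self, current: float) -> bool:
        # An improvement must beat the best value by more than min_delta.
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def step(self, current: float) -> bool:
        """Record one validation result; return True if training should stop."""
        if self._improved(current):
            self.best = current
            self.wait_count = 0
        else:
            self.wait_count += 1
        return self.wait_count >= self.patience


# Neither 8.8 nor 8.7 improves on the best value 9.0 by more than 0.5,
# so with patience=2 the second of them triggers a stop.
stopper = EarlyStoppingSketch(min_delta=0.5, patience=2, mode="min")
for epoch, loss in enumerate([10.0, 9.0, 8.8, 8.7, 8.0]):
    if stopper.step(loss):
        print(f"stopping after validation epoch {epoch}")  # prints epoch 3
        break
```

Because improvements are measured against the best value rather than the previous one, a slow downward drift smaller than min_delta still exhausts the patience budget.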

Methods table

fit(*args, **kwargs)

Fit the model.


Methods#*args, **kwargs)[source]#

Fit the model.