scvi.train.Trainer#

class scvi.train.Trainer(accelerator=None, devices=None, benchmark=True, check_val_every_n_epoch=None, max_epochs=400, default_root_dir=None, enable_checkpointing=False, num_sanity_val_steps=0, enable_model_summary=False, early_stopping=False, early_stopping_monitor='elbo_validation', early_stopping_min_delta=0.0, early_stopping_patience=45, early_stopping_mode='min', enable_progress_bar=True, progress_bar_refresh_rate=1, simple_progress_bar=True, logger=None, log_every_n_steps=10, replace_sampler_ddp=True, **kwargs)[source]#

Bases: TunableMixin, Trainer

Lightweight wrapper around the PyTorch Lightning Trainer.

Appropriate defaults are set for scvi-tools models, and callbacks such as EarlyStopping are configured through the parameters exposed here.

Parameters:
  • accelerator (Optional[Union[str, Accelerator]]) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • devices (Optional[Union[List[int], str, int]]) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".

  • benchmark (bool) – If True, enables cudnn.benchmark, which can improve speed when input sizes are fixed.

  • check_val_every_n_epoch (Optional[int]) – Run validation every n training epochs. By default, validation is not run unless early_stopping is True.

  • max_epochs (Tunable_[int]) – Stop training once this number of epochs is reached.

  • default_root_dir (Optional[str]) – Default path for logs and weights when no logger or checkpoint callback is passed. Defaults to scvi.settings.logging_dir. Can be a remote file path such as s3://mybucket/path or hdfs://path/.

  • enable_checkpointing (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.

  • num_sanity_val_steps (int) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders.

  • enable_model_summary (bool) – Whether to enable or disable the model summary.

  • early_stopping (bool) – Whether to perform early stopping with respect to the validation set. This automatically adds an EarlyStopping instance. A custom instance can be passed via the callbacks argument with this set to False (see the sketch after this parameter list).

  • early_stopping_monitor (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']) – Metric logged during the validation epoch. The available metrics depend on the training plan class used; the most common options are listed in the type annotation.

  • early_stopping_min_delta (float) – Minimum change in the monitored quantity to qualify as an improvement; an absolute change of less than min_delta counts as no improvement.

  • early_stopping_patience (int) – Number of validation epochs with no improvement after which training will be stopped.

  • early_stopping_mode (Literal['min', 'max']) – In ‘min’ mode, training stops when the monitored quantity has stopped decreasing; in ‘max’ mode, it stops when the monitored quantity has stopped increasing.

  • enable_progress_bar (bool) – Whether to enable or disable the progress bar.

  • progress_bar_refresh_rate (int) – How often to refresh progress bar (in steps). Value 0 disables progress bar.

  • simple_progress_bar (bool) – Use the custom scvi-tools simple progress bar (per epoch rather than per batch). When False, uses the default PyTorch Lightning progress bar, unless enable_progress_bar is False.

  • logger (Union[Logger, None, bool]) – A valid PyTorch Lightning logger. Defaults to a simple dictionary logger. If True, uses the default PyTorch Lightning logger.

  • log_every_n_steps (int) – How often to log within steps. This does not affect epoch-level logging.

  • replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If True, by default it will add shuffle=True for train sampler and shuffle=False for val/test sampler. If you want to customize it, you can set replace_sampler_ddp=False and add your own distributed sampler.

  • **kwargs – Other keyword arguments passed to the PyTorch Lightning Trainer.
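
A minimal construction sketch with early stopping enabled, as referenced above. In typical scvi-tools usage these arguments are instead passed to a model's train() method, which builds this Trainer internally; treat the snippet as illustrative rather than canonical for your installed version:

    import scvi

    # Direct construction; in practice these arguments are usually forwarded
    # through a model's train() method as trainer keyword arguments.
    trainer = scvi.train.Trainer(
        accelerator="auto",                        # choose CPU/GPU automatically
        devices="auto",
        max_epochs=400,
        early_stopping=True,                       # adds the default EarlyStopping callback
        early_stopping_monitor="elbo_validation",
        early_stopping_patience=45,
        early_stopping_mode="min",
    )

To supply a custom EarlyStopping instance instead, set early_stopping=False and pass the callback through callbacks, which is forwarded to the underlying PyTorch Lightning Trainer via **kwargs.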

Methods table#

fit(*args, **kwargs)

Fit the model.

Methods#

fit

Trainer.fit(*args, **kwargs)[source]#

Fit the model.
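
In scvi-tools, fit() is usually invoked indirectly through a model's train() method, which constructs this Trainer together with a training plan and data loaders. A hedged end-to-end sketch, assuming the standard SCVI model and the built-in synthetic_iid() demo dataset:

    import scvi

    # model.train() builds a Trainer like the one documented here and calls
    # its fit() with a training plan and data splitter under the hood.
    adata = scvi.data.synthetic_iid()
    scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
    model = scvi.model.SCVI(adata)
    model.train(
        max_epochs=50,
        early_stopping=True,  # trainer arguments documented above are accepted here
    )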