class scvi.train.Trainer(gpus=1, benchmark=True, flush_logs_every_n_steps=inf, check_val_every_n_epoch=None, max_epochs=400, default_root_dir=None, checkpoint_callback=False, num_sanity_val_steps=0, weights_summary=None, early_stopping=False, early_stopping_monitor='elbo_validation', early_stopping_min_delta=0.0, early_stopping_patience=45, early_stopping_mode='min', progress_bar_refresh_rate=1, simple_progress_bar=True, logger=None, log_every_n_steps=10, replace_sampler_ddp=False, **kwargs)[source]

Bases: pytorch_lightning.trainer.trainer.Trainer

Lightweight wrapper of Pytorch Lightning Trainer.

Appropriate defaults are set for scvi-tools models, as well as callbacks like EarlyStopping, with parameters accessible here.

gpus : int | str (default: 1)

Number of GPUs to train on (int), or which GPUs to train on (list or str), applied per node.

benchmark : bool (default: True)

If True, enables cudnn.benchmark, which improves speed when inputs are of fixed size.

flush_logs_every_n_steps : float (default: inf)

How often (in steps) to flush logs to disk. By default, logs are flushed only after training completes.

check_val_every_n_epoch : int | None (default: None)

Run validation every n training epochs. By default, validation is not run, unless early_stopping is True.

max_epochs : int (default: 400)

Stop training once this number of epochs is reached.

default_root_dir : str | None (default: None)

Default path for logs and weights when no logger or checkpoint callback is passed. Defaults to scvi.settings.logging_dir. Can be a remote file path such as s3://mybucket/path or hdfs://path/.

checkpoint_callback : bool (default: False)

If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.

num_sanity_val_steps : int (default: 0)

Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders.

weights_summary : {‘top’, ‘full’} | None (default: None)

Prints a summary of the weights when training begins.

early_stopping : bool (default: False)

Whether to perform early stopping with respect to the validation set. This automatically adds an EarlyStopping instance. A custom instance can be passed via the callbacks argument, with this set to False.

early_stopping_monitor : {‘elbo_validation’, ‘reconstruction_loss_validation’, ‘kl_local_validation’} (default: 'elbo_validation')

Metric logged during the validation epoch. The available metrics depend on the training plan class used; the most common options are listed in the type annotation.

early_stopping_min_delta : float (default: 0.0)

Minimum change in the monitored quantity to qualify as an improvement; an absolute change of less than min_delta counts as no improvement.

early_stopping_patience : int (default: 45)

Number of validation epochs with no improvement after which training will be stopped.

early_stopping_mode : {‘min’, ‘max’} (default: 'min')

In ‘min’ mode, training stops when the monitored quantity has stopped decreasing; in ‘max’ mode, it stops when the monitored quantity has stopped increasing.
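Taken together, early_stopping_min_delta, early_stopping_patience, and early_stopping_mode define a stopping rule equivalent to the following pure-Python sketch. This is a simplified illustration of the semantics, not the actual pytorch_lightning EarlyStopping implementation, which tracks state incrementally across epochs:

```python
def should_stop(history, min_delta=0.0, patience=45, mode="min"):
    """Return True if the monitored metric fails to improve by more than
    min_delta for `patience` consecutive validation epochs.

    Simplified sketch of EarlyStopping semantics for illustration only.
    """
    # Negate values in 'max' mode so we can always minimize internally.
    sign = 1.0 if mode == "min" else -1.0
    best = float("inf")
    wait = 0
    for value in history:
        # In 'min' mode an improvement is a decrease of more than min_delta;
        # in 'max' mode, an increase of more than min_delta.
        if sign * value < best - min_delta:
            best = sign * value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return True
    return False

# A validation ELBO that plateaus: training stops once `patience`
# consecutive epochs pass without improvement.
print(should_stop([10.0, 9.0, 9.0, 9.0, 9.0], patience=3, mode="min"))  # prints True
```

With the scvi-tools defaults (patience=45, min_delta=0.0, mode='min'), training stops only after the validation ELBO has failed to decrease at all for 45 consecutive validation epochs.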

progress_bar_refresh_rate : int (default: 1)

How often to refresh the progress bar (in steps). A value of 0 disables the progress bar.

simple_progress_bar : bool (default: True)

Use the custom scvi-tools simple progress bar (updated per epoch rather than per batch).

logger : pytorch_lightning.loggers.base.LightningLoggerBase | bool | None (default: None)

A valid PyTorch Lightning logger. Defaults to a simple dictionary logger. If True, uses the default PyTorch Lightning logger.

log_every_n_steps : int (default: 10)

How often to log within steps. This does not affect epoch-level logging.

replace_sampler_ddp : bool (default: False)

Explicitly enables or disables sampler replacement. If True, it will by default use shuffle=True for the train sampler and shuffle=False for the val/test samplers. To customize this behavior, set replace_sampler_ddp=False and add your own distributed sampler.


**kwargs

Other keyword arguments passed through to pytorch_lightning.trainer.Trainer.



fit(*args, **kwargs)

Runs the full optimization routine.