scvi.train.Trainer#

class scvi.train.Trainer(gpus=1, benchmark=True, flush_logs_every_n_steps=inf, check_val_every_n_epoch=None, max_epochs=400, default_root_dir=None, enable_checkpointing=False, num_sanity_val_steps=0, enable_model_summary=False, early_stopping=False, early_stopping_monitor='elbo_validation', early_stopping_min_delta=0.0, early_stopping_patience=45, early_stopping_mode='min', enable_progress_bar=False, progress_bar_refresh_rate=1, simple_progress_bar=True, logger=None, log_every_n_steps=10, replace_sampler_ddp=False, **kwargs)[source]#

Bases: Trainer

Lightweight wrapper of the PyTorch Lightning Trainer.

Appropriate defaults are set for scvi-tools models, as well as callbacks like EarlyStopping, with parameters accessible here.

Parameters:
gpus : int | str (default: 1)

Number of GPUs to train on (int) or which GPUs to train on (list or str), applied per node.

benchmark : bool (default: True)

If True, enables cudnn.benchmark, which improves speed when inputs are of fixed size.

check_val_every_n_epoch : int | None (default: None)

Run validation every n training epochs. By default, validation is not run unless early_stopping is True.

max_epochs : int (default: 400)

Stop training once this number of epochs is reached.

default_root_dir : str | None (default: None)

Default path for logs and weights when no logger or checkpoint callback is passed. Defaults to scvi.settings.logging_dir. Can be a remote path such as s3://mybucket/path or hdfs://path/.

enable_checkpointing : bool (default: False)

If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.

num_sanity_val_steps : int (default: 0)

Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders.

enable_model_summary : bool (default: False)

Whether to enable or disable the model summarization.

early_stopping : bool (default: False)

Whether to perform early stopping with respect to the validation set. This automatically adds an EarlyStopping instance. A custom instance can be passed via the callbacks argument with this set to False (see the sketch after this parameter list).

early_stopping_monitor : {‘elbo_validation’, ‘reconstruction_loss_validation’, ‘kl_local_validation’} (default: 'elbo_validation')

Metric logged during the validation epoch. The available metrics depend on the training plan class used; the most common options are listed in the type annotation.

early_stopping_min_delta : float (default: 0.0)

Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta counts as no improvement.

early_stopping_patience : int (default: 45)

Number of validation epochs with no improvement after which training will be stopped.

early_stopping_mode : {‘min’, ‘max’} (default: 'min')

In ‘min’ mode, training will stop when the monitored quantity has stopped decreasing; in ‘max’ mode, it will stop when the monitored quantity has stopped increasing.

enable_progress_bar : bool (default: False)

Whether to enable or disable the default PyTorch Lightning progress bar.

progress_bar_refresh_rate : int (default: 1)

How often to refresh the progress bar (in steps). A value of 0 disables the progress bar.

simple_progress_bar : bool (default: True)

Use the custom scvi-tools simple progress bar (updated per epoch rather than per batch).

logger : LightningLoggerBase | None | bool (default: None)

A valid PyTorch Lightning logger. Defaults to a simple dictionary logger. If True, uses the default PyTorch Lightning logger.

log_every_n_steps : int (default: 10)

How often to log within steps. This does not affect epoch-level logging.

replace_sampler_ddp : bool (default: False)

Explicitly enables or disables sampler replacement. If True, by default it will add shuffle=True for train sampler and shuffle=False for val/test sampler. If you want to customize it, you can set replace_sampler_ddp=False and add your own distributed sampler.

**kwargs

Other keyword arguments for the base PyTorch Lightning Trainer.
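
A minimal usage sketch (not part of the original docstring): constructing the wrapper directly with early-stopping options, or supplying a custom EarlyStopping callback through callbacks. In scvi-tools these arguments are usually forwarded through a model's train() method rather than set on a hand-built Trainer; the values below are illustrative.

    from pytorch_lightning.callbacks import EarlyStopping
    from scvi.train import Trainer

    # Default route: let the wrapper configure EarlyStopping from these arguments.
    trainer = Trainer(
        gpus=1,                                    # train on a single GPU
        max_epochs=400,                            # upper bound on training epochs
        early_stopping=True,                       # adds an EarlyStopping callback
        early_stopping_monitor="elbo_validation",  # metric checked on the validation set
        early_stopping_patience=45,                # validation epochs without improvement
        early_stopping_mode="min",                 # stop when the ELBO stops decreasing
    )

    # Alternative route: pass a custom EarlyStopping instance via `callbacks`
    # (forwarded to the base Trainer) and keep `early_stopping=False`.
    trainer = Trainer(
        early_stopping=False,
        callbacks=[EarlyStopping(monitor="elbo_validation", patience=20, mode="min")],
    )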

Attributes table#

Methods table#

fit(*args, **kwargs)

Runs the full optimization routine.

Attributes#

accelerator#

Trainer.accelerator: Accelerator#
Return type:

Accelerator

amp_backend#

Trainer.amp_backend#
Return type:

str | None

callback_metrics#

Trainer.callback_metrics#
Return type:

dict

callbacks#

Trainer.callbacks: List[Callback] = []#

checkpoint_callback#

Trainer.checkpoint_callback#

The first ModelCheckpoint callback in the Trainer.callbacks list, or None if it doesn’t exist.

Return type:

ModelCheckpoint | None

checkpoint_callbacks#

Trainer.checkpoint_callbacks#

A list of all instances of ModelCheckpoint found in the Trainer.callbacks list.

Return type:

List[ModelCheckpoint]

current_epoch#

Trainer.current_epoch: int#
Return type:

int

data_parallel#

Trainer.data_parallel#
Return type:

bool

data_parallel_device_ids#

Trainer.data_parallel_device_ids#
Return type:

List[int] | None

default_root_dir#

Trainer.default_root_dir#

The default location to save artifacts of loggers, checkpoints, etc.

It is used as a fallback if the logger or checkpoint callback does not define a specific save path.

Return type:

str

devices#

Trainer.devices#
Return type:

int | str | List[int] | None

disable_validation#

Trainer.disable_validation#

Check if validation is disabled during training.

Return type:

bool

distributed_sampler_kwargs#

Trainer.distributed_sampler_kwargs: dict#
Return type:

dict | None

early_stopping_callback#

Trainer.early_stopping_callback#

The first EarlyStopping callback in the Trainer.callbacks list, or None if it doesn’t exist.

Return type:

EarlyStopping | None

early_stopping_callbacks#

Trainer.early_stopping_callbacks#

A list of all instances of EarlyStopping found in the Trainer.callbacks list.

Return type:

List[EarlyStopping]

enable_validation#

Trainer.enable_validation#

Check if we should run validation during training.

Return type:

bool

evaluating#

Trainer.evaluating#
Return type:

bool

fit_loop#

Trainer.fit_loop#
Return type:

FitLoop

global_rank#

Trainer.global_rank#
Return type:

int

global_step#

Trainer.global_step#
Return type:

int

gpus#

Trainer.gpus#
Return type:

int | str | List[int] | None

interrupted#

Trainer.interrupted#
Return type:

bool

ipus#

Trainer.ipus#
Return type:

int

is_global_zero#

Trainer.is_global_zero#
Return type:

bool

is_last_batch#

Trainer.is_last_batch#
Return type:

bool

lightning_module#

Trainer.lightning_module: pl.LightningModule#
Return type:

LightningModule

lightning_optimizers#

Trainer.lightning_optimizers#
Return type:

List[LightningOptimizer]

local_rank#

Trainer.local_rank#
Return type:

int

log_dir#

Trainer.log_dir#
Return type:

str | None

logged_metrics#

Trainer.logged_metrics#
Return type:

dict

lr_schedulers#

Trainer.lr_schedulers#
Return type:

List[Union[_LRScheduler, ReduceLROnPlateau]]

max_epochs#

Trainer.max_epochs#
Return type:

int

max_steps#

Trainer.max_steps#
Return type:

int

min_epochs#

Trainer.min_epochs#
Return type:

int | None

min_steps#

Trainer.min_steps#
Return type:

int | None

model#

Trainer.model#

The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.

To access the pure LightningModule, use lightning_module instead.

Return type:

Module

node_rank#

Trainer.node_rank#
Return type:

int

num_gpus#

Trainer.num_gpus#
Return type:

int

num_nodes#

Trainer.num_nodes#
Return type:

int

num_processes#

Trainer.num_processes#
Return type:

int

optimizer_frequencies#

Trainer.optimizer_frequencies#
Return type:

list

optimizers#

Trainer.optimizers#
Return type:

List[Optimizer]

precision#

Trainer.precision#
Return type:

str | int

precision_plugin#

Trainer.precision_plugin#
Return type:

PrecisionPlugin

predict_loop#

Trainer.predict_loop#
Return type:

PredictionLoop

predicting#

Trainer.predicting#
Return type:

bool

prediction_writer_callbacks#

Trainer.prediction_writer_callbacks#

A list of all instances of BasePredictionWriter found in the Trainer.callbacks list.

Return type:

List[BasePredictionWriter]

progress_bar_callback#

Trainer.progress_bar_callback#
Return type:

ProgressBarBase | None

progress_bar_dict#

Trainer.progress_bar_dict#

Read-only access to the progress bar metrics.

Return type:

dict

progress_bar_metrics#

Trainer.progress_bar_metrics#
Return type:

dict

resume_from_checkpoint#

Trainer.resume_from_checkpoint#
Return type:

str | Path | None

root_gpu#

Trainer.root_gpu#
Return type:

int | None

sanity_checking#

Trainer.sanity_checking#
Return type:

bool

scaler#

Trainer.scaler#

should_rank_save_checkpoint#

Trainer.should_rank_save_checkpoint#
Return type:

bool

slurm_job_id#

Trainer.slurm_job_id#
Return type:

int | None

terminate_on_nan#

Trainer.terminate_on_nan#
Return type:

bool

test_loop#

Trainer.test_loop#
Return type:

EvaluationLoop

testing#

Trainer.testing#
Return type:

bool

tpu_cores#

Trainer.tpu_cores#
Return type:

int

train_loop#

Trainer.train_loop#
Return type:

FitLoop

training#

Trainer.training#
Return type:

bool

training_type_plugin#

Trainer.training_type_plugin#
Return type:

TrainingTypePlugin

tuning#

Trainer.tuning#
Return type:

bool

use_amp#

Trainer.use_amp#
Return type:

bool

validate_loop#

Trainer.validate_loop#
Return type:

EvaluationLoop

validating#

Trainer.validating#
Return type:

bool

weights_save_path#

Trainer.weights_save_path#

The default root location to save weights (checkpoints), e.g., when the ModelCheckpoint does not define a file path.

Return type:

str

weights_summary#

Trainer.weights_summary#
Return type:

str | None

world_size#

Trainer.world_size#
Return type:

int

val_check_interval#

Trainer.val_check_interval: float#

reload_dataloaders_every_n_epochs#

Trainer.reload_dataloaders_every_n_epochs: int#

tpu_local_core_rank#

Trainer.tpu_local_core_rank: int#

train_dataloader#

Trainer.train_dataloader: DataLoader#

limit_train_batches#

Trainer.limit_train_batches: Union[int, float]#

num_training_batches#

Trainer.num_training_batches: int#

val_check_batch#

Trainer.val_check_batch: float#

val_dataloaders#

Trainer.val_dataloaders: List[DataLoader]#

limit_val_batches#

Trainer.limit_val_batches: Union[int, float]#

num_val_batches#

Trainer.num_val_batches: List[int]#

test_dataloaders#

Trainer.test_dataloaders: List[DataLoader]#

limit_test_batches#

Trainer.limit_test_batches: Union[int, float]#

num_test_batches#

Trainer.num_test_batches: List[int]#

predict_dataloaders#

Trainer.predict_dataloaders: List[DataLoader]#

limit_predict_batches#

Trainer.limit_predict_batches: Union[int, float]#

num_predict_batches#

Trainer.num_predict_batches: List[int]#

log_every_n_steps#

Trainer.log_every_n_steps: int#

overfit_batches#

Trainer.overfit_batches: Union[int, float]#

accelerator_connector#

Trainer.accelerator_connector: AcceleratorConnector#

Methods#

fit#

Trainer.fit(*args, **kwargs)[source]#

Runs the full optimization routine.

Parameters:
model

Model to fit.

train_dataloaders

A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, see the PyTorch Lightning documentation on multiple dataloaders.

val_dataloaders

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

ckpt_path

Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, an exception is raised. If resuming from a mid-epoch checkpoint, training will start from the beginning of the next epoch.

datamodule

An instance of LightningDataModule.
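
A hedged sketch of calling fit directly: the call follows the PyTorch Lightning signature, taking a LightningModule (in scvi-tools, typically a training plan) plus dataloaders or a LightningDataModule. In normal use this call is issued internally when a model's train() method runs; the training_plan, train_dl, and val_dl names below are hypothetical placeholders.

    from scvi.train import Trainer

    trainer = Trainer(max_epochs=400, early_stopping=True)
    trainer.fit(
        training_plan,               # LightningModule to optimize (e.g. an scvi-tools training plan)
        train_dataloaders=train_dl,  # dataloader(s) with training samples
        val_dataloaders=val_dl,      # dataloader(s) with validation samples (needed for early stopping)
    )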