scvi.train.Trainer
- class scvi.train.Trainer(gpus=1, benchmark=True, flush_logs_every_n_steps=inf, check_val_every_n_epoch=None, max_epochs=400, default_root_dir=None, enable_checkpointing=False, num_sanity_val_steps=0, enable_model_summary=False, early_stopping=False, early_stopping_monitor='elbo_validation', early_stopping_min_delta=0.0, early_stopping_patience=45, early_stopping_mode='min', enable_progress_bar=True, progress_bar_refresh_rate=1, simple_progress_bar=True, logger=None, log_every_n_steps=10, replace_sampler_ddp=False, **kwargs)
Bases: Trainer
Lightweight wrapper of the PyTorch Lightning Trainer.
Appropriate defaults are set for scvi-tools models, and common callbacks such as EarlyStopping are configured, with their parameters accessible here.
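A minimal sketch of constructing the wrapper directly (in typical workflows a model's train() method creates it for you and forwards these keyword arguments). The gpus=0 setting and the overridden values below are illustrative choices, not requirements:

```python
from scvi.train import Trainer

# Illustrative configuration: override a few of the scvi-tools defaults.
trainer = Trainer(
    gpus=0,                      # assumption: CPU-only run; pass 1 for a single GPU
    max_epochs=100,              # stop after 100 epochs instead of the default 400
    early_stopping=True,         # adds the EarlyStopping callback described below
    early_stopping_monitor="elbo_validation",
    early_stopping_patience=45,
)
```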
- Parameters:
  - gpus (Union[int, str], default: 1) – Number of GPUs to train on (int) or which GPUs to train on (list or str), applied per node.
  - benchmark (bool, default: True) – If True, enables cudnn.benchmark, which improves speed when inputs are of fixed size.
  - check_val_every_n_epoch (Optional[int], default: None) – Check validation every n training epochs. By default, validation is not checked, unless early_stopping is True.
  - max_epochs (int, default: 400) – Stop training once this number of epochs is reached.
  - default_root_dir (Optional[str], default: None) – Default path for logs and weights when no logger or checkpoint callback is passed. Defaults to scvi.settings.logging_dir. Can be a remote file path such as s3://mybucket/path or hdfs://path/.
  - enable_checkpointing (bool, default: False) – If True, enable checkpointing. This configures a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.
  - num_sanity_val_steps (int, default: 0) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders.
  - enable_model_summary (bool, default: False) – Whether to enable or disable the model summarization.
  - early_stopping (bool, default: False) – Whether to perform early stopping with respect to the validation set. This automatically adds an EarlyStopping instance. A custom instance can be passed via the callbacks argument, with this set to False.
  - early_stopping_monitor (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'], default: 'elbo_validation') – Metric logged during the validation epoch. The available metrics depend on the training plan class used; the most common options are listed in the typing.
  - early_stopping_min_delta (float, default: 0.0) – Minimum change in the monitored quantity to qualify as an improvement, i.e., an absolute change of less than min_delta counts as no improvement.
  - early_stopping_patience (int, default: 45) – Number of validation epochs with no improvement after which training will be stopped.
  - early_stopping_mode (Literal['min', 'max'], default: 'min') – In 'min' mode, training stops when the monitored quantity has stopped decreasing; in 'max' mode, it stops when the monitored quantity has stopped increasing.
  - enable_progress_bar (bool, default: True) – Whether to enable or disable the progress bar.
  - progress_bar_refresh_rate (int, default: 1) – How often to refresh the progress bar (in steps). A value of 0 disables the progress bar.
  - simple_progress_bar (bool, default: True) – Use the custom scvi-tools simple progress bar (per epoch rather than per batch). When False, uses the default PyTorch Lightning progress bar, unless enable_progress_bar is False.
  - logger (Union[LightningLoggerBase, None, bool], default: None) – A valid PyTorch Lightning logger. Defaults to a simple dictionary logger. If True, defaults to the default PyTorch Lightning logger.
  - log_every_n_steps (int, default: 10) – How often to log within steps. This does not affect epoch-level logging.
  - replace_sampler_ddp (bool, default: False) – Explicitly enables or disables sampler replacement. If True, by default it adds shuffle=True for the train sampler and shuffle=False for the val/test samplers. To customize, set replace_sampler_ddp=False and add your own distributed sampler.
  - **kwargs – Other keyword arguments for Trainer.
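For reference, a hedged end-to-end sketch of setting the early-stopping parameters above through a model's train() method, which forwards unrecognized keyword arguments to this Trainer. It assumes scvi.model.SCVI and scvi's built-in synthetic dataset:

```python
import scvi
from scvi.model import SCVI

adata = scvi.data.synthetic_iid()  # small synthetic AnnData, for illustration only
SCVI.setup_anndata(adata)

model = SCVI(adata)
model.train(
    max_epochs=400,
    early_stopping=True,                       # also enables validation checking
    early_stopping_monitor="elbo_validation",  # one of the Literal options above
    early_stopping_min_delta=0.0,
    early_stopping_patience=45,
    early_stopping_mode="min",
)
```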
Methods

- fit – Fit the model.
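Because this class subclasses pytorch_lightning.Trainer, fit follows the Lightning interface. A hedged sketch of driving training manually, continuing from the SCVI model above; TrainingPlan and AnnDataLoader are scvi-tools components that model.train() normally wires together for you, and their exact signatures may vary across scvi-tools versions:

```python
from scvi.dataloaders import AnnDataLoader
from scvi.train import Trainer, TrainingPlan

# Wrap the model's underlying module in a LightningModule and build a dataloader.
training_plan = TrainingPlan(model.module)
train_dl = AnnDataLoader(model.adata_manager, shuffle=True, batch_size=128)

trainer = Trainer(gpus=0, max_epochs=10)   # assumption: short CPU-only run
trainer.fit(training_plan, train_dl)       # standard pytorch_lightning fit call
```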