scvi.train.TrainRunner

class scvi.train.TrainRunner(model, training_plan, data_splitter, max_epochs, use_gpu=None, accelerator='auto', devices='auto', **trainer_kwargs)

Bases: object

TrainRunner calls Trainer.fit() and handles the pre- and post-training procedures.
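The split between setup, fitting and bookkeeping can be illustrated with a toy, dependency-free sketch. Every name in it (`ToyPlan`, `ToyTrainer`, `ToyModel`, `ToyRunner`) is hypothetical and only mimics the idea; it is not the scvi-tools implementation. The particular pre- and post-training steps shown (setting `n_obs_training`, copying metrics into `history_`, setting `is_trained_`) are assumptions chosen for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyPlan:
    """Hypothetical stand-in for a training plan."""

    n_obs_training: int = 0

    def step(self, epoch: int) -> float:
        # Fake loss value that shrinks with each epoch.
        return 1.0 / (epoch + 1)


class ToyTrainer:
    """Hypothetical stand-in for a Lightning Trainer."""

    def __init__(self, max_epochs: int, **trainer_kwargs: Any) -> None:
        self.max_epochs = max_epochs
        self.kwargs = trainer_kwargs  # extra options, stored untouched
        self.logged: Dict[str, List[float]] = {}

    def fit(self, plan: ToyPlan) -> None:
        # Record one fake loss value per epoch.
        self.logged["train_loss"] = [plan.step(e) for e in range(self.max_epochs)]


@dataclass
class ToyModel:
    """Hypothetical model object that receives the training results."""

    is_trained_: bool = False
    history_: Dict[str, List[float]] = field(default_factory=dict)


class ToyRunner:
    """Hypothetical runner: pre-training setup, fit(), post-training bookkeeping."""

    def __init__(self, model: ToyModel, plan: ToyPlan, n_train: int,
                 max_epochs: int, **trainer_kwargs: Any) -> None:
        self.model = model
        self.plan = plan
        self.n_train = n_train
        # Extra keyword arguments are forwarded to the trainer.
        self.trainer = ToyTrainer(max_epochs=max_epochs, **trainer_kwargs)

    def __call__(self) -> None:
        # Pre-training: tell the plan how many training observations exist.
        self.plan.n_obs_training = self.n_train
        self.trainer.fit(self.plan)
        # Post-training: copy logged metrics onto the model and flag it trained.
        self.model.history_ = dict(self.trainer.logged)
        self.model.is_trained_ = True


model = ToyModel()
runner = ToyRunner(model, ToyPlan(), n_train=900, max_epochs=3)
runner()
print(model.is_trained_)  # True
```

Keeping the runner a thin wrapper lets the model, plan and splitter be swapped independently, which matches how the parameters below are passed in already initialized.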

Parameters
  • model (BaseModelClass) – model to train

  • training_plan (LightningModule) – initialized TrainingPlan

  • data_splitter (Union[SemiSupervisedDataSplitter, DataSplitter]) – initialized SemiSupervisedDataSplitter or DataSplitter

  • max_epochs (int) – maximum number of epochs to train for

  • use_gpu (Union[str, int, bool, None] (default: None)) – Use default GPU if available (if True), or index of GPU to use (if int), or name of GPU (if str, e.g., ‘cuda:0’), or use CPU (if False). Passing in use_gpu != None will override accelerator and devices arguments. This argument is deprecated in v1.0 and will be removed in v1.1. Please use accelerator and devices instead.

  • accelerator (str (default: 'auto')) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • devices (Union[int, List[int], str] (default: 'auto')) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or “auto” for automatic selection based on the chosen accelerator.

  • trainer_kwargs – extra keyword arguments passed to Trainer
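The documented `use_gpu` semantics (True, an int index, a string such as ‘cuda:0’, False, or None) amount to a translation onto `accelerator` and `devices`. The function `resolve_device_args` and its `gpu_available` flag are hypothetical; this is a sketch of the behaviour described above, not the scvi-tools code, and mapping `use_gpu=True` to `devices=1` is an assumption.

```python
import warnings
from typing import List, Tuple, Union

DeviceSpec = Union[int, List[int], str]


def resolve_device_args(
    use_gpu: Union[str, int, bool, None],
    accelerator: str = "auto",
    devices: DeviceSpec = "auto",
    gpu_available: bool = True,
) -> Tuple[str, DeviceSpec]:
    """Hypothetical mapping of the deprecated use_gpu flag onto accelerator/devices."""
    if use_gpu is None:
        # No legacy flag: the new-style arguments are used unchanged.
        return accelerator, devices
    warnings.warn(
        "use_gpu is deprecated; pass accelerator and devices instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Booleans are checked before ints, because True/False are ints in Python.
    if use_gpu is False:
        return "cpu", "auto"
    if use_gpu is True:
        # "Default GPU if available", otherwise fall back to the CPU.
        return ("gpu", 1) if gpu_available else ("cpu", "auto")
    if isinstance(use_gpu, int):
        return "gpu", [use_gpu]  # a specific GPU index
    if isinstance(use_gpu, str) and use_gpu.startswith("cuda:"):
        return "gpu", [int(use_gpu.split(":", 1)[1])]  # "cuda:0" -> [0]
    raise ValueError(f"Unrecognized use_gpu value: {use_gpu!r}")


# Any non-None use_gpu overrides the accelerator/devices arguments.
print(resolve_device_args(None, "cpu", 2))  # ('cpu', 2)
```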

Examples

>>> # The following code should be within a method of a BaseModelClass subclass
>>> from scvi.dataloaders import DataSplitter
>>> from scvi.train import TrainingPlan, TrainRunner
>>> data_splitter = DataSplitter(self.adata)
>>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
>>> runner = TrainRunner(
...     self,
...     training_plan=training_plan,
...     data_splitter=data_splitter,
...     max_epochs=max_epochs,
... )
>>> runner()
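In the example, `len(data_splitter.train_idx)` is the number of observations assigned to the training set. The hypothetical, dependency-free `split_indices` function below shows the general idea of a shuffled train/validation split; it is not scvi-tools' `DataSplitter`, and the default `train_size=0.9` and the rounding rule are assumptions.

```python
import random
from typing import List, Tuple


def split_indices(
    n_obs: int, train_size: float = 0.9, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical shuffled train/validation split over observation indices."""
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0, 1]")
    indices = list(range(n_obs))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    n_train = round(train_size * n_obs)  # number of training observations
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = split_indices(n_obs=1000, train_size=0.9)
# len(train_idx) plays the role of len(data_splitter.train_idx).
print(len(train_idx), len(val_idx))  # 900 100
```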

Methods table

Methods