class scvi.train.TrainRunner(model, training_plan, data_splitter, max_epochs, accelerator='auto', devices='auto', **trainer_kwargs)

Bases: object

TrainRunner runs training and handles the pre- and post-training procedures around it.

Parameters:

  • model (BaseModelClass) – model to train

  • training_plan (LightningModule) – initialized TrainingPlan

  • data_splitter (Union[SemiSupervisedDataSplitter, DataSplitter]) – initialized SemiSupervisedDataSplitter or DataSplitter

  • max_epochs (int) – maximum number of epochs to train for

  • accelerator (str (default: 'auto')) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • devices (Union[int, list[int], str] (default: 'auto')) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or “auto” for automatic selection based on the chosen accelerator.

  • trainer_kwargs – additional keyword arguments passed to Trainer


>>> from scvi.dataloaders import DataSplitter
>>> from scvi.train import TrainingPlan, TrainRunner
>>> # The following code should be within a method of a BaseModelClass subclass,
>>> # where `max_epochs` is supplied by the caller.
>>> data_splitter = DataSplitter(self.adata_manager)
>>> training_plan = TrainingPlan(self.module)
>>> runner = TrainRunner(
...     self,
...     training_plan=training_plan,
...     data_splitter=data_splitter,
...     max_epochs=max_epochs,
... )
>>> runner()
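
To make the "construct, then call" pattern concrete without installing scvi-tools, here is a toy, pure-Python sketch of a callable runner that performs a pre-training step (materializing the data split), a training loop over `max_epochs`, and a post-training step (recording history and marking the model as trained). Every class and attribute here (`ToyModel`, `ToySplitter`, `ToyTrainingPlan`, `ToyTrainRunner`, `train_indices`, `history_`, `is_trained_`) is a hypothetical stand-in chosen for illustration; it is not scvi-tools' implementation, which delegates the loop to its PyTorch Lightning-based `Trainer`.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModel:
    """Hypothetical stand-in for a BaseModelClass subclass."""
    is_trained_: bool = False
    train_indices: list = field(default_factory=list)
    history_: dict = field(default_factory=dict)


class ToySplitter:
    """Hypothetical splitter: the first `train_size` fraction of indices is training data."""

    def __init__(self, n_obs: int, train_size: float = 0.9):
        self.n_obs = n_obs
        self.train_size = train_size
        self.train_idx: list = []
        self.val_idx: list = []

    def setup(self) -> None:
        n_train = int(self.n_obs * self.train_size)
        self.train_idx = list(range(n_train))
        self.val_idx = list(range(n_train, self.n_obs))


class ToyTrainingPlan:
    """Hypothetical plan whose loss halves after every epoch."""

    def __init__(self, initial_loss: float = 1.0):
        self.loss = initial_loss

    def training_epoch(self) -> float:
        self.loss /= 2
        return self.loss


class ToyTrainRunner:
    """Hypothetical runner: configure in __init__, train when called."""

    def __init__(self, model, training_plan, data_splitter, max_epochs, **trainer_kwargs):
        self.model = model
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        self.max_epochs = max_epochs
        # A real runner would forward these to its trainer; the toy only stores them.
        self.trainer_kwargs = trainer_kwargs

    def __call__(self):
        # Pre-training: build the split and record it on the model.
        self.data_splitter.setup()
        self.model.train_indices = self.data_splitter.train_idx
        # Training: one plan step per epoch.
        losses = [self.training_plan.training_epoch() for _ in range(self.max_epochs)]
        # Post-training: store the loss history and mark the model as trained.
        self.model.history_["train_loss"] = losses
        self.model.is_trained_ = True
        return self.model
```

Usage mirrors the example above: build the splitter and plan, pass them to the runner with `max_epochs`, then call the runner, e.g. `ToyTrainRunner(ToyModel(), training_plan=ToyTrainingPlan(), data_splitter=ToySplitter(10), max_epochs=3)()`. Keeping construction separate from execution lets the caller inspect or adjust the configured runner before training starts.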

Methods table