class scvi.train.TrainRunner(model, training_plan, data_splitter, max_epochs, use_gpu=None, **trainer_kwargs)

Bases: object

TrainRunner calls Trainer.fit() and handles pre- and post-training procedures.

Parameters

model : BaseModelClass

Model to train.

training_plan : LightningModule

Initialized TrainingPlan.

data_splitter : SemiSupervisedDataSplitter | DataSplitter

Initialized SemiSupervisedDataSplitter or DataSplitter.

max_epochs : int

Maximum number of epochs to train for.

use_gpu : str | int | bool | None (default: None)

Use the default GPU if available (if None or True), the GPU with the given index (if int), the GPU with the given name (if str, e.g., 'cuda:0'), or the CPU (if False).
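
To make the use_gpu rules concrete, here is a minimal sketch that maps each documented value to a PyTorch-style device string. resolve_device and its cuda_available flag are hypothetical names introduced for illustration; this is not how scvi-tools resolves devices internally. The text does not say what happens when a specific GPU is requested but none is available, so the sketch does not check that case.

```python
from typing import Optional, Union


def resolve_device(
    use_gpu: Optional[Union[str, int, bool]],
    cuda_available: bool,
) -> str:
    """Hypothetical helper mirroring the documented use_gpu rules.

    Not the scvi-tools implementation; returns PyTorch-style device strings.
    """
    # Handle the bool cases first: bool is a subclass of int in Python.
    if use_gpu is False:
        return "cpu"
    if use_gpu is None or use_gpu is True:
        # Default GPU ("cuda" = current CUDA device) if available, else CPU.
        return "cuda" if cuda_available else "cpu"
    if isinstance(use_gpu, int):
        # An integer selects a GPU by index.
        return f"cuda:{use_gpu}"
    if isinstance(use_gpu, str):
        # A string is used as the device name, e.g. "cuda:0".
        return use_gpu
    raise TypeError(f"unsupported use_gpu value: {use_gpu!r}")
```

Checking False and True before the int branch matters: because isinstance(True, int) is True in Python, reversing the order would turn use_gpu=True into "cuda:1".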


**trainer_kwargs

Extra keyword arguments passed to Trainer.


Examples

>>> # The following code should be within a method of a subclass of BaseModelClass
>>> from scvi.dataloaders import DataSplitter
>>> from scvi.train import TrainingPlan, TrainRunner
>>> data_splitter = DataSplitter(self.adata)
>>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
>>> runner = TrainRunner(
...     self,
...     training_plan=training_plan,
...     data_splitter=data_splitter,
...     max_epochs=max_epochs,
... )
>>> runner()
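
The class description says TrainRunner handles pre- and post-training procedures around the Trainer, and the example ends by calling runner(). The following is a minimal, framework-free sketch of that runner pattern. ToyTrainer, ToyPlan, ToySplitter, ToyModel and ToyRunner are hypothetical stand-ins, not scvi-tools or Lightning classes, and the specific steps shown (building the trainer with forwarded keyword arguments, then copying the loss history and training indices onto the model) are assumptions chosen for illustration.

```python
from typing import Any, Dict, List


class ToyTrainer:
    """Hypothetical stand-in for a Lightning-style Trainer."""

    def __init__(self, max_epochs: int, **kwargs: Any) -> None:
        self.max_epochs = max_epochs
        self.kwargs = kwargs  # extra options forwarded by the runner
        self.history: List[float] = []

    def fit(self, plan: "ToyPlan", splitter: "ToySplitter") -> None:
        for _ in range(self.max_epochs):
            self.history.append(plan.training_step(splitter.train_idx))


class ToyPlan:
    """Hypothetical training plan that returns a fake loss per epoch."""

    def __init__(self) -> None:
        self.loss = 1.0

    def training_step(self, batch: List[int]) -> float:
        self.loss /= 2  # pretend the loss halves every epoch
        return self.loss


class ToySplitter:
    """Hypothetical splitter: first train_frac of indices go to training."""

    def __init__(self, n_obs: int, train_frac: float = 0.8) -> None:
        n_train = int(n_obs * train_frac)
        self.train_idx = list(range(n_train))
        self.val_idx = list(range(n_train, n_obs))


class ToyModel:
    """Hypothetical model holding the attributes the runner fills in."""

    def __init__(self) -> None:
        self.is_trained_ = False
        self.history_: Dict[str, List[float]] = {}
        self.train_indices: List[int] = []


class ToyRunner:
    """Sketch of the runner pattern: set up, fit, then record results."""

    def __init__(self, model: ToyModel, training_plan: ToyPlan,
                 data_splitter: ToySplitter, max_epochs: int,
                 **trainer_kwargs: Any) -> None:
        self.model = model
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        # Pre-training: build the trainer, forwarding any extra kwargs.
        self.trainer = ToyTrainer(max_epochs=max_epochs, **trainer_kwargs)

    def __call__(self) -> None:
        self.trainer.fit(self.training_plan, self.data_splitter)
        # Post-training: copy results and split indices onto the model.
        self.model.history_["train_loss"] = list(self.trainer.history)
        self.model.train_indices = self.data_splitter.train_idx
        self.model.is_trained_ = True


model = ToyModel()
runner = ToyRunner(model, ToyPlan(), ToySplitter(n_obs=10), max_epochs=3, verbose=False)
runner()
print(model.history_["train_loss"])  # [0.5, 0.25, 0.125]
```

Keeping setup in the constructor and the fit plus bookkeeping in __call__ mirrors the two-step usage in the example: construct the runner, then call it.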

Methods table