scvi.train.TrainRunner

class scvi.train.TrainRunner(model, training_plan, data_splitter, max_epochs, use_gpu=None, **trainer_kwargs)

Bases: object

TrainRunner calls Trainer.fit() and handles the pre-training and post-training procedures around it.
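The runner pattern can be illustrated with a minimal, self-contained sketch. It does not reproduce the scvi-tools implementation: `StubTrainer`, `SketchRunner`, and the `is_trained_` flag set afterwards are hypothetical names chosen for illustration, and the "pre" and "post" steps are placeholders. It only shows the ordering described above (setup, then `fit()`, then bookkeeping) and how `max_epochs` and extra keyword arguments could be forwarded to a trainer.

```python
from types import SimpleNamespace
from typing import Any, List


class StubTrainer:
    """Hypothetical stand-in for a Lightning-style Trainer that only records the call."""

    def __init__(self, max_epochs: int, **kwargs: Any) -> None:
        self.max_epochs = max_epochs
        self.kwargs = kwargs  # extra trainer options, forwarded untouched
        self.fit_args = None

    def fit(self, training_plan: Any, data_module: Any) -> None:
        # A real trainer would run the optimization loop here.
        self.fit_args = (training_plan, data_module)


class SketchRunner:
    """Hypothetical runner: pre-training setup, trainer.fit(), post-training bookkeeping."""

    def __init__(self, model: Any, training_plan: Any, data_splitter: Any,
                 max_epochs: int, **trainer_kwargs: Any) -> None:
        self.model = model
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        self.trainer = StubTrainer(max_epochs=max_epochs, **trainer_kwargs)
        self.steps: List[str] = []

    def __call__(self) -> List[str]:
        self.steps.append("pre")  # placeholder for setup work, e.g. device placement
        self.trainer.fit(self.training_plan, self.data_splitter)
        self.steps.append("fit")
        self.model.is_trained_ = True  # hypothetical post-training bookkeeping
        self.steps.append("post")
        return self.steps


model = SimpleNamespace(is_trained_=False)
runner = SketchRunner(model, "plan", "splitter", max_epochs=5, check_val_every_n_epoch=1)
print(runner())  # ['pre', 'fit', 'post']
```

Making the runner callable keeps the training entry point to a single `runner()` call, which matches how the example at the end of this page invokes `TrainRunner`.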

Parameters
model : BaseModelClass

Model to train.

training_plan : LightningModule

Initialized TrainingPlan.

data_splitter : SemiSupervisedDataSplitter | DataSplitter

Initialized SemiSupervisedDataSplitter or DataSplitter.

max_epochs : int

Maximum number of epochs to train for.

use_gpu : str | int | bool | None (default: None)

Use the default GPU if available (if None or True), the GPU at the given index (if int), the GPU with the given name (if str, e.g., 'cuda:0'), or the CPU (if False).
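The `use_gpu` rules above amount to a small dispatch on the argument's type. The sketch below restates them as code; `resolve_device` and its `(accelerator, device)` return shape are hypothetical and are not the scvi-tools implementation. GPU availability is passed in as `cuda_available` to keep the example free of a torch dependency (in practice a caller might obtain it from `torch.cuda.is_available()`), and a device of `None` stands for "let the backend pick its default GPU".

```python
from typing import Optional, Tuple, Union


def resolve_device(
    use_gpu: Optional[Union[str, int, bool]], cuda_available: bool
) -> Tuple[str, Optional[Union[int, str]]]:
    """Hypothetical helper: map a use_gpu value to (accelerator, device).

    A device of None means "let the backend choose its default GPU".
    """
    # Check booleans and None first: bool is a subclass of int in Python.
    if use_gpu is False:
        return "cpu", None
    if use_gpu is None or use_gpu is True:
        return ("gpu", None) if cuda_available else ("cpu", None)
    if isinstance(use_gpu, (int, str)):
        # Explicit GPU index (e.g. 1) or device name (e.g. "cuda:0").
        return "gpu", use_gpu
    raise TypeError(f"unsupported use_gpu value: {use_gpu!r}")


for value in (None, True, False, 1, "cuda:0"):
    print(value, resolve_device(value, cuda_available=True))
```

The identity checks on `True` and `False` come before the `isinstance(..., int)` branch on purpose. Otherwise `True` would be treated as GPU index 1.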

trainer_kwargs

Additional keyword arguments passed to Trainer.

Examples

>>> # The following code should be placed inside a method of a BaseModelClass subclass
>>> from scvi.dataloaders import DataSplitter
>>> from scvi.train import TrainingPlan, TrainRunner
>>> data_splitter = DataSplitter(self.adata)
>>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
>>> runner = TrainRunner(
...     self,
...     training_plan=training_plan,
...     data_splitter=data_splitter,
...     max_epochs=max_epochs,
... )
>>> runner()

Methods