scvi.train.TrainRunner
- class scvi.train.TrainRunner(model, training_plan, data_splitter, max_epochs, use_gpu=None, **trainer_kwargs)
Bases: object
TrainRunner calls Trainer.fit() and handles pre- and post-training procedures.
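The division of labour described above, where a wrapper does setup, hands the training loop to a trainer's fit() method, and then does bookkeeping, can be illustrated with a minimal sketch of the runner pattern. ToyModel, ToyTrainer and ToyRunner are hypothetical stand-ins, not scvi-tools classes. The specific pre- and post-training steps shown (validating max_epochs, storing a loss history, setting an is_trained flag) are assumptions chosen for illustration and are not a description of what TrainRunner itself does.

```python
"""Hypothetical sketch of the runner pattern: wrap a trainer's fit() call with
pre- and post-training steps. None of these classes are scvi-tools APIs."""

from dataclasses import dataclass, field


@dataclass
class ToyModel:
    # Hypothetical stand-in for a model; records training state and history.
    is_trained: bool = False
    history: dict = field(default_factory=dict)


class ToyTrainer:
    # Hypothetical stand-in for a Lightning-style Trainer exposing fit().
    def __init__(self, max_epochs: int):
        self.max_epochs = max_epochs

    def fit(self, training_plan, data_splitter):
        # Pretend to train: return one fake loss value per epoch.
        return [1.0 / (epoch + 1) for epoch in range(self.max_epochs)]


class ToyRunner:
    def __init__(self, model, training_plan, data_splitter, max_epochs):
        self.model = model
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        self.trainer = ToyTrainer(max_epochs=max_epochs)

    def __call__(self):
        # Pre-training step (assumed): validate settings before handing off.
        if self.trainer.max_epochs < 1:
            raise ValueError("max_epochs must be >= 1")
        # Delegate the optimisation loop to the trainer.
        losses = self.trainer.fit(self.training_plan, self.data_splitter)
        # Post-training steps (assumed): store history and mark the model trained.
        self.model.history["train_loss"] = losses
        self.model.is_trained = True
        return self.model
```

For instance, `ToyRunner(ToyModel(), training_plan=None, data_splitter=None, max_epochs=3)()` returns the model with `is_trained` set and three recorded loss values. Keeping the loop inside the trainer and the bookkeeping inside the runner lets the same trainer be reused unchanged across different models.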
- Parameters:
- model : BaseModelClass
Model to train.
- training_plan : LightningModule
Initialized TrainingPlan.
- data_splitter : SemiSupervisedDataSplitter | DataSplitter
Initialized SemiSupervisedDataSplitter or DataSplitter.
- max_epochs : int
Maximum number of epochs to train for.
- use_gpu : str | int | bool | None (default: None)
Use the default GPU if available (if None or True), the GPU at the given index (if int), the GPU with the given name (if str, e.g., ‘cuda:0’), or the CPU (if False).
- trainer_kwargs
Extra keyword arguments for Trainer.
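The use_gpu rules listed above amount to a small dispatch on the argument's type. The sketch below encodes them as a function; resolve_device is a hypothetical helper, not part of scvi-tools, and both the "cuda:N" device strings and the CPU fallback when None or True is passed on a machine without a GPU are assumptions made for illustration.

```python
from typing import Optional, Union


def resolve_device(use_gpu: Optional[Union[str, int, bool]], gpu_available: bool) -> str:
    """Hypothetical helper: map a use_gpu value to a device string.

    Follows the documented rules; not scvi-tools code.
    """
    # Test the booleans first: True/False are instances of int in Python,
    # so they must be handled before the generic int branch.
    if use_gpu is False:
        return "cpu"
    if use_gpu is None or use_gpu is True:
        # Default GPU if one is available; the CPU fallback is an assumption.
        return "cuda:0" if gpu_available else "cpu"
    if isinstance(use_gpu, int):
        # An integer selects a GPU by index.
        return f"cuda:{use_gpu}"
    if isinstance(use_gpu, str):
        # A string is taken as the device name, e.g. "cuda:0".
        return use_gpu
    raise TypeError(f"unsupported use_gpu value: {use_gpu!r}")
```

The identity checks against True and False come before isinstance(use_gpu, int) because bool is a subclass of int; reversing the order would turn True into "cuda:1".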
Examples
>>> # The following code should live in a method of a BaseModelClass subclass
>>> from scvi.dataloaders import DataSplitter
>>> from scvi.train import TrainingPlan, TrainRunner
>>> data_splitter = DataSplitter(self.adata)
>>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
>>> runner = TrainRunner(
...     self,
...     training_plan=training_plan,
...     data_splitter=data_splitter,
...     max_epochs=max_epochs)
>>> runner()