scvi.train.TrainRunner

class scvi.train.TrainRunner(model, training_plan, data_splitter, max_epochs, use_gpu=None, **trainer_kwargs)

Bases: object

TrainRunner calls Trainer.fit() and handles the pre- and post-training procedures.
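The pattern described above (set up a trainer, call its fit() method, then do post-training bookkeeping) can be sketched in plain Python. This is a minimal, hypothetical illustration, not scvi-tools' implementation: StubTrainer, SketchTrainRunner, ToyModel and the is_trained_ flag are assumptions made for the example, and the stub's fit() does no actual optimization.

```python
from typing import Any


class StubTrainer:
    """Hypothetical stand-in for a Lightning-style Trainer that only records its config."""

    def __init__(self, max_epochs: int, **kwargs: Any) -> None:
        self.max_epochs = max_epochs
        self.kwargs = kwargs
        self.fit_calls = 0

    def fit(self, training_plan: Any, data_splitter: Any) -> None:
        # A real trainer would run the optimization loop here; the stub only counts calls.
        self.fit_calls += 1


class SketchTrainRunner:
    """Hypothetical runner: pre-training setup, trainer.fit(), post-training bookkeeping."""

    def __init__(self, model, training_plan, data_splitter, max_epochs, **trainer_kwargs):
        self.model = model
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        # Pre-training: build the trainer, forwarding any extra kwargs unchanged.
        self.trainer = StubTrainer(max_epochs=max_epochs, **trainer_kwargs)

    def __call__(self):
        self.trainer.fit(self.training_plan, self.data_splitter)
        # Post-training: mark the model as trained (an assumed bookkeeping step).
        self.model.is_trained_ = True
        return self.model


class ToyModel:
    """Hypothetical model object that starts out untrained."""

    is_trained_ = False


model = ToyModel()
runner = SketchTrainRunner(model, "plan", "splitter", max_epochs=5, check_val_every_n_epoch=1)
runner()
print(model.is_trained_)  # True
```

Keeping trainer construction in the runner, rather than in the model, lets extra trainer options travel through **trainer_kwargs without the model needing to know about them.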

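Parameters:

model – The model to train. In the example below, this is the BaseModelClass subclass instance (self).
training_plan – An initialized training plan, such as TrainingPlan.
data_splitter – An initialized data splitter, such as DataSplitter.
max_epochs – Maximum number of epochs to train for.
use_gpu – GPU selection setting; defaults to None.
**trainer_kwargs – Additional keyword arguments passed to the Trainer.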

Examples

>>> # The following code should be placed inside a method of a BaseModelClass subclass
>>> from scvi.dataloaders import DataSplitter
>>> from scvi.train import TrainingPlan, TrainRunner
>>> data_splitter = DataSplitter(self.adata)
>>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
>>> runner = TrainRunner(
...     self,
...     training_plan=training_plan,
...     data_splitter=data_splitter,
...     max_epochs=max_epochs,
... )
>>> runner()

Methods table

Methods