scvi.train.TrainRunner
- class scvi.train.TrainRunner(model, training_plan, data_splitter, max_epochs, accelerator='auto', devices='auto', trainer_config=None, **trainer_kwargs)
Bases: object

TrainRunner calls Trainer.fit() and handles pre- and post-training procedures.
- Parameters:
  - model (BaseModelClass) – Model to train.
  - training_plan (LightningModule) – Initialized TrainingPlan.
  - data_splitter (SemiSupervisedDataSplitter | DataSplitter) – Initialized SemiSupervisedDataSplitter or DataSplitter.
  - max_epochs (int) – Maximum number of epochs to train for.
  - accelerator (str (default: 'auto')) – Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") as well as custom accelerator instances.
  - devices (int | list[int] | str (default: 'auto')) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate that all available devices should be used, or "auto" for automatic selection based on the chosen accelerator.
  - trainer_config (Mapping[str, Any] | KwargsConfig | None (default: None)) – Configuration for Trainer. Values here are merged with trainer_kwargs; explicitly passed trainer_kwargs take precedence.
  - trainer_kwargs – Extra keyword arguments for Trainer.
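The precedence rule for trainer_config and trainer_kwargs can be illustrated with a small standalone sketch. This is not scvi-tools' actual implementation: merge_trainer_options is a hypothetical helper written only for illustration, it covers only the plain Mapping case (not KwargsConfig), and the option names in the usage lines are ordinary Lightning Trainer arguments chosen as examples.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any


def merge_trainer_options(
    trainer_config: Mapping[str, Any] | None = None,
    **trainer_kwargs: Any,
) -> dict[str, Any]:
    """Hypothetical helper: combine config values with explicit kwargs.

    Values from ``trainer_config`` are used as a base, and any explicitly
    passed keyword argument overrides the config value with the same key.
    """
    merged = dict(trainer_config or {})  # copy so the caller's mapping is untouched
    merged.update(trainer_kwargs)        # explicit kwargs take precedence
    return merged


# Example: the config enables the progress bar, the explicit kwarg disables it.
merged = merge_trainer_options(
    {"log_every_n_steps": 10, "enable_progress_bar": True},
    enable_progress_bar=False,
)
print(merged)  # {'log_every_n_steps': 10, 'enable_progress_bar': False}
```

Keys that appear only in the config are kept unchanged, so a shared configuration can be reused across runs while individual calls override single options.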
Examples
>>> from scvi.dataloaders import DataSplitter
>>> from scvi.train import TrainingPlan, TrainRunner
>>> # The following code should be within a subclass of BaseModelClass
>>> data_splitter = DataSplitter(self.adata)
>>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
>>> runner = TrainRunner(
...     self,
...     training_plan=training_plan,  # fixed typo: was "trianing_plan"
...     data_splitter=data_splitter,
...     max_epochs=max_epochs,
... )
>>> runner()