scvi.model.base.UnsupervisedTrainingMixin
Methods table
| train | Train the model. |
Methods
- UnsupervisedTrainingMixin.train(max_epochs=None, accelerator='auto', devices='auto', train_size=None, validation_size=None, shuffle_set_split=True, load_sparse_tensor=False, batch_size=128, early_stopping=False, datasplitter_kwargs=None, plan_config=None, plan_kwargs=None, datamodule=None, trainer_config=None, **trainer_kwargs)
Train the model.
- Parameters:
  - max_epochs (int | None (default: None)) – The maximum number of epochs to train the model. The actual number of epochs may be less if early stopping is enabled. If None, defaults to a heuristic based on get_max_epochs_heuristic(). Must be passed in if datamodule is passed in and it does not have an n_obs attribute.
  - accelerator (str (default: 'auto')) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.
  - devices (int | list[int] | str (default: 'auto')) – The devices to use. Can be set to a non-negative index (int or str), a sequence of device indices (list or comma-separated str), the value -1 to indicate all available devices, or “auto” for automatic selection based on the chosen accelerator. If set to “auto” and accelerator is not determined to be “cpu”, then devices will be set to the first available device.
  - train_size (float | None (default: None)) – Size of the training set as a fraction in the range [0.0, 1.0]. If None, the effective value is 0.9, and a small last batch may be added to the validation cells. Passed into DataSplitter. Not used if datamodule is passed in.
  - validation_size (float | None (default: None)) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set. Passed into DataSplitter. Not used if datamodule is passed in.
  - shuffle_set_split (bool (default: True)) – Whether to shuffle indices before splitting. If False, the validation, training, and test sets are split in the sequential order of the data according to the validation_size and train_size percentages. Passed into DataSplitter. Not used if datamodule is passed in.
  - load_sparse_tensor (bool (default: False)) – EXPERIMENTAL. If True, loads data with sparse CSR or CSC layout as a Tensor with the same layout. Can lead to speedups in data transfers to GPUs, depending on the sparsity of the data. Passed into DataSplitter. Not used if datamodule is passed in.
  - batch_size (int (default: 128)) – Minibatch size to use during training. Passed into DataSplitter. Not used if datamodule is passed in.
  - early_stopping (bool (default: False)) – Whether to perform early stopping. Additional arguments can be passed in through **trainer_kwargs. See Trainer for further options.
  - datasplitter_kwargs (dict | None (default: None)) – Additional keyword arguments passed into DataSplitter. Values in this argument can be overwritten by arguments directly passed into this method, when appropriate. Not used if datamodule is passed in.
  - plan_config (Mapping[str, Any] | KwargsConfig | None (default: None)) – Configuration object or mapping used to build TrainingPlan. Values in plan_kwargs and explicit arguments take precedence.
  - plan_kwargs (Mapping[str, Any] | KwargsConfig | None (default: None)) – Additional keyword arguments passed into TrainingPlan. Values in this argument can be overwritten by arguments directly passed into this method, when appropriate.
  - datamodule (LightningDataModule | None (default: None)) – EXPERIMENTAL. A LightningDataModule instance to use for training in place of the default DataSplitter. Can only be passed in if the model was not initialized with AnnData.
  - trainer_config (Mapping[str, Any] | KwargsConfig | None (default: None)) – Configuration object or mapping used to build Trainer. Values in trainer_kwargs and explicit arguments take precedence.
  - **trainer_kwargs – Additional keyword arguments passed into Trainer.
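The relationship between train_size, validation_size, and the implicit test set can be easier to follow as arithmetic. The sketch below is an illustration only and not the DataSplitter implementation: the helper name split_counts is hypothetical, the rounding (training count rounded up, validation count rounded down) is an assumption, and it ignores the small-last-batch adjustment mentioned under train_size.

```python
import math


def split_counts(n_obs, train_size=None, validation_size=None):
    """Return (n_train, n_val, n_test) for a dataset with n_obs cells.

    Hypothetical helper that mirrors the parameter descriptions only.
    """
    # Per the train_size description, None behaves like 0.9.
    if train_size is None:
        train_size = 0.9
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in the range (0.0, 1.0]")

    # Assumption of this sketch: round the training count up.
    n_train = math.ceil(train_size * n_obs)

    if validation_size is None:
        # "Defaults to 1 - train_size": every non-training cell is validation.
        n_val = n_obs - n_train
    else:
        if train_size + validation_size > 1.0 + 1e-9:
            raise ValueError("train_size + validation_size must not exceed 1")
        # Assumption of this sketch: round the validation count down.
        n_val = math.floor(validation_size * n_obs)

    # Whatever is left over forms the test set.
    n_test = n_obs - n_train - n_val
    return n_train, n_val, n_test


print(split_counts(1000))                                       # (900, 100, 0)
print(split_counts(1000, train_size=0.8, validation_size=0.1))  # (800, 100, 100)
```

With the defaults there is no test set; one appears only when the two fractions sum to less than 1.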
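The precedence described for plan_config / plan_kwargs and trainer_config / trainer_kwargs (configuration object first, then keyword mapping, then explicit arguments) can be pictured as a layered dictionary merge. This is a minimal sketch under stated assumptions, not the library's code: resolve_options is a hypothetical helper, it handles plain mappings only (not KwargsConfig objects), treating an explicit None as "not set" is an assumption, and the option names in the usage lines are illustrative.

```python
from collections.abc import Mapping
from typing import Any, Optional


def resolve_options(
    config: Optional[Mapping[str, Any]] = None,
    extra_kwargs: Optional[Mapping[str, Any]] = None,
    **explicit: Any,
) -> dict[str, Any]:
    """Merge option sources from lowest to highest precedence (hypothetical)."""
    resolved: dict[str, Any] = {}
    # Lowest precedence first: the config mapping, then the kwargs mapping.
    for source in (config, extra_kwargs):
        if source is not None:
            resolved.update(source)
    # Highest precedence: explicit arguments. In this sketch, an explicit
    # value of None means "not provided" and does not override anything.
    resolved.update({k: v for k, v in explicit.items() if v is not None})
    return resolved


options = resolve_options(
    config={"max_epochs": 100, "log_every_n_steps": 10},
    extra_kwargs={"max_epochs": 50, "gradient_clip_val": 1.0},
    max_epochs=20,
    accelerator=None,
)
print(options)
# {'max_epochs': 20, 'log_every_n_steps': 10, 'gradient_clip_val': 1.0}
```

Each later layer only overrides the keys it actually sets, so a config can hold project-wide defaults while individual calls adjust one or two values.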