scvi.model.base.UnsupervisedTrainingMixin
Methods table

| Method | Description |
| --- | --- |
| `train` | Train the model. |

Methods
- `UnsupervisedTrainingMixin.train(max_epochs=None, accelerator='auto', devices='auto', train_size=0.9, validation_size=None, shuffle_set_split=True, load_sparse_tensor=False, batch_size=128, early_stopping=False, datasplitter_kwargs=None, plan_kwargs=None, datamodule=None, **trainer_kwargs)`
Train the model.
- Parameters:
  - `max_epochs` (`int | None`, default: `None`) – The maximum number of epochs to train the model. Fewer epochs may run if early stopping is enabled. If `None`, defaults to a heuristic based on `get_max_epochs_heuristic()`. Must be passed in if `datamodule` is passed in and does not have an `n_obs` attribute.
  - `accelerator` (`str`, default: `'auto'`) – Supports passing different accelerator types (`"cpu"`, `"gpu"`, `"tpu"`, `"ipu"`, `"hpu"`, `"mps"`, `"auto"`) as well as custom accelerator instances.
  - `devices` (`int | list[int] | str`, default: `'auto'`) – The devices to use. Can be set to a non-negative index (`int` or `str`), a sequence of device indices (`list` or comma-separated `str`), the value `-1` to use all available devices, or `"auto"` for automatic selection based on the chosen accelerator. If set to `"auto"` and the accelerator is not determined to be `"cpu"`, the first available device is used.
  - `train_size` (`float`, default: `0.9`) – Size of the training set, as a fraction of the data in the range `[0.0, 1.0]`. Passed into `DataSplitter`. Not used if `datamodule` is passed in.
  - `validation_size` (`float | None`, default: `None`) – Size of the validation set, as a fraction of the data. If `None`, defaults to `1 - train_size`. If `train_size + validation_size < 1`, the remaining cells belong to a test set. Passed into `DataSplitter`. Not used if `datamodule` is passed in.
  - `shuffle_set_split` (`bool`, default: `True`) – Whether to shuffle indices before splitting. If `False`, the validation, training, and test sets are taken in the sequential order of the data, according to the `validation_size` and `train_size` fractions. Passed into `DataSplitter`. Not used if `datamodule` is passed in.
  - `load_sparse_tensor` (`bool`, default: `False`) – EXPERIMENTAL. If `True`, loads data stored with a sparse CSR or CSC layout as a `Tensor` with the same layout. Depending on the sparsity of the data, this can speed up data transfers to GPUs. Passed into `DataSplitter`. Not used if `datamodule` is passed in.
  - `batch_size` (`int`, default: `128`) – Minibatch size to use during training. Passed into `DataSplitter`. Not used if `datamodule` is passed in.
  - `early_stopping` (`bool`, default: `False`) – Whether to perform early stopping. Additional arguments can be passed in through `**trainer_kwargs`. See `Trainer` for further options.
  - `datasplitter_kwargs` (`dict | None`, default: `None`) – Additional keyword arguments passed into `DataSplitter`. Where applicable, values in this dictionary are overridden by arguments passed directly to this method. Not used if `datamodule` is passed in.
  - `plan_kwargs` (`dict | None`, default: `None`) – Additional keyword arguments passed into `TrainingPlan`. Where applicable, values in this dictionary are overridden by arguments passed directly to this method.
  - `datamodule` (`LightningDataModule | None`, default: `None`) – EXPERIMENTAL. A `LightningDataModule` instance to use for training in place of the default `DataSplitter`. Can only be passed in if the model was not initialized with `AnnData`.
  - `**trainer_kwargs` – Additional keyword arguments passed into `Trainer`.
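The rules for `max_epochs` (explicit value, data-size heuristic, or an error when a `datamodule` lacks `n_obs`) can be sketched in plain Python. This is not scvi-tools code: `max_epochs_heuristic` and `resolve_max_epochs` are hypothetical names, and the heuristic's shape (a cap of 400 epochs that shrinks proportionally above 20,000 observations) is an assumption modeled on `get_max_epochs_heuristic()`; consult that function for the actual constants.

```python
from typing import Optional


def max_epochs_heuristic(n_obs: int, epochs_cap: int = 400, decay_at_n_obs: int = 20_000) -> int:
    """Hypothetical stand-in for get_max_epochs_heuristic(); the constants are assumptions."""
    # Datasets up to decay_at_n_obs cells get the full cap; larger ones get
    # proportionally fewer epochs.
    max_epochs = min(round((decay_at_n_obs / n_obs) * epochs_cap), epochs_cap)
    # Never return fewer than one epoch.
    return max(max_epochs, 1)


def resolve_max_epochs(max_epochs: Optional[int], n_obs: Optional[int]) -> int:
    """Hypothetical helper following the documented rules for max_epochs.

    n_obs is the number of observations in the AnnData object or the datamodule,
    or None when a datamodule without an n_obs attribute is used.
    """
    if max_epochs is not None:
        # An explicit value is always used as-is.
        return max_epochs
    if n_obs is None:
        # Documented rule: without n_obs, max_epochs must be passed in.
        raise ValueError("max_epochs must be passed in when the datamodule has no n_obs attribute.")
    return max_epochs_heuristic(n_obs)


print(resolve_max_epochs(None, 5_000))    # 400 (full cap for a small dataset)
print(resolve_max_epochs(None, 100_000))  # 80 (fewer epochs for a larger dataset)
print(resolve_max_epochs(25, None))       # 25 (explicit value wins)
```

The design point the sketch illustrates is that the heuristic needs an observation count, which is why a `datamodule` without `n_obs` forces an explicit `max_epochs`.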
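How `train_size`, `validation_size`, and `shuffle_set_split` partition the cells can be illustrated with a small standard-library sketch. `split_indices` is a hypothetical function, not part of scvi-tools. Rounding the training share up and the validation share down is an assumption about `DataSplitter`, as is shuffling with `random.Random(seed)`. The ordering (validation, then training, then any remaining test cells) follows the `shuffle_set_split` description.

```python
import math
import random
from typing import List, Optional, Tuple


def split_indices(
    n_obs: int,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    shuffle_set_split: bool = True,
    seed: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical sketch of a DataSplitter-style split; the rounding rules are assumed."""
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0, 1].")
    n_train = math.ceil(train_size * n_obs)  # assumption: round the training share up
    if validation_size is None:
        # Documented default: validation_size = 1 - train_size, so no test set remains.
        n_val = n_obs - n_train
    else:
        if not 0.0 <= validation_size < 1.0 or train_size + validation_size > 1.0:
            raise ValueError("Invalid validation_size for the given train_size.")
        n_val = math.floor(validation_size * n_obs)  # assumption: round down
    indices = list(range(n_obs))
    if shuffle_set_split:
        random.Random(seed).shuffle(indices)
    # Validation first, then training; whatever is left forms the test set.
    val_idx = indices[:n_val]
    train_idx = indices[n_val : n_val + n_train]
    test_idx = indices[n_val + n_train :]
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = split_indices(8, train_size=0.5, validation_size=0.25, shuffle_set_split=False)
print(val_idx, train_idx, test_idx)  # [0, 1] [2, 3, 4, 5] [6, 7]
```

With `shuffle_set_split=False`, the sets follow the data's existing order, so any ordering in the data (for example, cells sorted by batch) carries over into the split.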
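The descriptions of `datasplitter_kwargs` and `plan_kwargs` state that, where applicable, arguments passed directly to `train()` take precedence. A minimal sketch of that precedence rule as a dictionary merge follows. `merge_kwargs` is a hypothetical helper and the dictionary keys are illustrative only; which keys count as "applicable" is decided by each model, not by this code.

```python
from typing import Any, Dict, Optional


def merge_kwargs(user_kwargs: Optional[Dict[str, Any]], **direct: Any) -> Dict[str, Any]:
    """Hypothetical helper: directly passed arguments override the user's dict."""
    merged = dict(user_kwargs or {})  # copy so the caller's dict is not modified; None -> {}
    merged.update(direct)  # on key collisions, the directly passed value wins
    return merged


# The keys below are illustrative only.
datasplitter_kwargs = {"batch_size": 64, "drop_last": True}
merged = merge_kwargs(datasplitter_kwargs, batch_size=256)
print(merged)  # {'batch_size': 256, 'drop_last': True}
```

Copying before updating keeps the caller's dictionary reusable across several `train()` calls.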
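With `early_stopping=True`, training can end before `max_epochs` once a monitored validation metric stops improving. The sketch below shows the general patience-based rule in plain Python. It is a generic illustration, not the Lightning `EarlyStopping` callback or scvi-tools code. The names `stopping_epoch`, `patience`, and `min_delta` are assumptions chosen to mirror common early-stopping options; the options that `Trainer` actually accepts through `**trainer_kwargs` are listed in its documentation.

```python
from typing import Sequence


def stopping_epoch(val_losses: Sequence[float], patience: int, min_delta: float = 0.0) -> int:
    """Generic sketch: number of epochs run before a patience-based stop."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            # Improvement larger than min_delta: remember it and reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # No sufficient improvement for `patience` consecutive epochs.
                return epoch
    # The rule never triggered, so every epoch (up to max_epochs) was run.
    return len(val_losses)


losses = [10.0, 8.0, 7.5, 7.6, 7.7, 7.8, 7.0]
print(stopping_epoch(losses, patience=3))  # 6
```

Here the late improvement at epoch 7 is never seen, which is the usual trade-off of a small patience value.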