scvi.model.base.UnsupervisedTrainingMixin

class scvi.model.base.UnsupervisedTrainingMixin

General-purpose unsupervised training method.

Methods table

train([max_epochs, accelerator, devices, ...]) – Train the model.

Methods

UnsupervisedTrainingMixin.train(max_epochs=None, accelerator='auto', devices='auto', train_size=0.9, validation_size=None, shuffle_set_split=True, load_sparse_tensor=False, batch_size=128, early_stopping=False, datasplitter_kwargs=None, plan_kwargs=None, datamodule=None, **trainer_kwargs)

Train the model.

Parameters:
  • max_epochs (int | None (default: None)) – The maximum number of epochs to train the model. The actual number of epochs may be fewer if early stopping is enabled. If None, defaults to a heuristic based on get_max_epochs_heuristic(). Must be passed in if datamodule is passed in and does not have an n_obs attribute.

  • accelerator (str (default: 'auto')) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • devices (int | list[int] | str (default: 'auto')) – The devices to use. Can be set to a non-negative index (int or str), a sequence of device indices (list or comma-separated str), the value -1 to indicate all available devices, or “auto” for automatic selection based on the chosen accelerator. If set to “auto” and accelerator is not determined to be “cpu”, then devices will be set to the first available device.

  • train_size (float (default: 0.9)) – Size of training set in the range [0.0, 1.0]. Passed into DataSplitter. Not used if datamodule is passed in.

  • validation_size (float | None (default: None)) – Size of the validation set, in the range [0.0, 1.0]. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set. Passed into DataSplitter. Not used if datamodule is passed in.

  • shuffle_set_split (bool (default: True)) – Whether to shuffle indices before splitting. If False, the validation, training, and test sets are taken in that order from the sequential order of the data, according to the validation_size and train_size fractions. Passed into DataSplitter. Not used if datamodule is passed in.

  • load_sparse_tensor (bool (default: False)) – EXPERIMENTAL. If True, loads data stored in sparse CSR or CSC layout as a Tensor with the same layout. This can speed up data transfers to GPUs, depending on the sparsity of the data. Passed into DataSplitter. Not used if datamodule is passed in.

  • batch_size (int (default: 128)) – Minibatch size to use during training. Passed into DataSplitter. Not used if datamodule is passed in.

  • early_stopping (bool (default: False)) – Whether to perform early stopping. Additional arguments can be passed in through **trainer_kwargs. See Trainer for further options.

  • datasplitter_kwargs (dict | None (default: None)) – Additional keyword arguments passed into DataSplitter. Values in this argument can be overwritten by arguments directly passed into this method, when appropriate. Not used if datamodule is passed in.

  • plan_kwargs (dict | None (default: None)) – Additional keyword arguments passed into TrainingPlan. Values in this argument can be overwritten by arguments directly passed into this method, when appropriate.

  • datamodule (LightningDataModule | None (default: None)) – EXPERIMENTAL. A LightningDataModule instance to use for training in place of the default DataSplitter. Can only be passed in if the model was not initialized with AnnData.

  • **trainer_kwargs – Additional keyword arguments passed into Trainer.
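The max_epochs default described above can be illustrated with a short sketch. It assumes the heuristic trains for at most 400 epochs and scales the count down in proportion to the number of cells once a dataset exceeds 20,000 cells. This is our reading of get_max_epochs_heuristic() rather than a documented contract, so check the scvi-tools source before relying on the constants. The names max_epochs_heuristic_sketch and resolve_max_epochs_sketch are hypothetical and not part of scvi-tools.

```python
from typing import Optional


def max_epochs_heuristic_sketch(n_obs: int, epochs_cap: int = 400,
                                decay_at_n_obs: int = 20_000) -> int:
    """Hypothetical stand-in for the default max_epochs rule (assumed constants)."""
    if n_obs <= 0:
        raise ValueError("n_obs must be positive")
    # Up to decay_at_n_obs cells: train for the full cap.
    # Beyond that: scale the epoch count down in proportion to n_obs.
    return int(min(round((decay_at_n_obs / n_obs) * epochs_cap), epochs_cap))


def resolve_max_epochs_sketch(max_epochs: Optional[int], n_obs: Optional[int]) -> int:
    """Pick max_epochs the way the parameter description reads (hypothetical helper)."""
    if max_epochs is not None:
        return max_epochs  # an explicit value always wins
    if n_obs is None:
        # e.g. a datamodule without an n_obs attribute: no heuristic is possible
        raise ValueError("max_epochs must be passed when n_obs is unknown")
    return max_epochs_heuristic_sketch(n_obs)
```

Under these assumptions, a 5,000-cell dataset keeps the full 400 epochs, while a 100,000-cell dataset defaults to 80.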
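The interaction of train_size, validation_size, and shuffle_set_split can be sketched in plain Python. The sketch assumes the training count is rounded up and the validation count rounded down. It also assumes that, without shuffling, cells are sliced in the order validation, training, test, as the shuffle_set_split description suggests. The rounding rules are our assumption, and split_sizes_sketch and split_indices_sketch are hypothetical helpers, not part of the DataSplitter API.

```python
import math
import random
from typing import List, Optional, Tuple


def split_sizes_sketch(n_obs: int, train_size: float = 0.9,
                       validation_size: Optional[float] = None) -> Tuple[int, int, int]:
    """Hypothetical helper: (n_train, n_val, n_test) for the given fractions."""
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0, 1]")
    n_train = math.ceil(train_size * n_obs)  # assumption: training count rounds up
    if validation_size is None:
        n_val = n_obs - n_train  # validation defaults to 1 - train_size
    else:
        if not 0.0 <= validation_size < 1.0 or train_size + validation_size > 1.0:
            raise ValueError("invalid validation_size")
        n_val = math.floor(validation_size * n_obs)  # assumption: rounds down
    n_test = n_obs - n_train - n_val  # whatever is left becomes the test set
    return n_train, n_val, n_test


def split_indices_sketch(n_obs: int, n_train: int, n_val: int,
                         shuffle: bool = True,
                         seed: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: (train, validation, test) index lists."""
    indices = list(range(n_obs))
    if shuffle:
        random.Random(seed).shuffle(indices)  # corresponds to shuffle_set_split=True
    # Slices are taken in the order validation, training, test.
    val = indices[:n_val]
    train = indices[n_val:n_val + n_train]
    test = indices[n_val + n_train:]
    return train, val, test
```

With 1,000 cells, train_size=0.8 and validation_size=0.1, this yields 800 training, 100 validation, and 100 test cells. With the defaults (train_size=0.9, validation_size=None), the test set is empty.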
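Setting early_stopping=True can end training before max_epochs once a monitored validation metric stops improving. The sketch below shows generic patience-based early stopping in the style of Lightning's EarlyStopping callback with mode "min". It is an illustration, not scvi-tools' code. The settings it models (patience, min_delta) are forwarded to Trainer through **trainer_kwargs under Trainer's own argument names, so check the Trainer documentation for those names and their defaults. early_stop_epoch_sketch is a hypothetical helper.

```python
import math
from typing import Sequence


def early_stop_epoch_sketch(val_losses: Sequence[float], patience: int,
                            min_delta: float = 0.0) -> int:
    """Hypothetical helper: epoch at which patience-based early stopping halts.

    Mode "min": a loss counts as an improvement only if it is lower than the best
    loss so far by more than min_delta. Training stops after `patience`
    consecutive epochs without improvement; otherwise it runs every epoch given.
    """
    best = math.inf
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best, wait = loss, 0  # improvement: reset the counter
        else:
            wait += 1  # no improvement this epoch
            if wait >= patience:
                return epoch
    return len(val_losses)  # max_epochs reached without triggering the stop
```

For example, validation losses of 5, 4, 3, 3, 3, 3 with patience=2 stop training after epoch 5. A larger min_delta makes small improvements count as stagnation, so training stops sooner.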