scvi.model.base.PyroSviTrainMixin

- class scvi.model.base.PyroSviTrainMixin
Mixin class for training Pyro models.
Supports training with minibatches or with the full dataset (the latter copies the data to the GPU only once).
Methods table

| train | Train the model. |

Methods
- PyroSviTrainMixin.train(max_epochs=None, accelerator='auto', device='auto', train_size=0.9, validation_size=None, shuffle_set_split=True, batch_size=128, early_stopping=False, lr=None, training_plan=None, datasplitter_kwargs=None, plan_kwargs=None, **trainer_kwargs)
Train the model.
- Parameters:
  - max_epochs (int | None, default: None) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).
  - accelerator (str, default: 'auto') – Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") as well as custom accelerator instances.
  - device (int | str, default: 'auto') – The device to use. Can be set to a non-negative index (int or str) or "auto" for automatic selection based on the chosen accelerator. If set to "auto" and the accelerator is not determined to be "cpu", the device will be set to the first available device.
  - train_size (float, default: 0.9) – Size of the training set, in the range [0.0, 1.0].
  - validation_size (float | None, default: None) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.
  - shuffle_set_split (bool, default: True) – Whether to shuffle indices before splitting. If False, the val, train, and test sets are split in the sequential order of the data according to validation_size and train_size percentages.
  - batch_size (int, default: 128) – Minibatch size to use during training. If None, no minibatching occurs and all data is copied to the device (e.g., GPU).
  - early_stopping (bool, default: False) – Whether to perform early stopping. Additional arguments can be passed in **kwargs. See Trainer for further options.
  - lr (float | None, default: None) – Optimiser learning rate (the default optimiser is ClippedAdam). Specifying the optimiser via plan_kwargs overrides this choice of lr.
  - training_plan (PyroTrainingPlan | None, default: None) – Training plan; defaults to PyroTrainingPlan.
  - datasplitter_kwargs (dict | None, default: None) – Additional keyword arguments passed into DataSplitter.
  - plan_kwargs (dict | None, default: None) – Keyword args for PyroTrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.
  - **trainer_kwargs – Other keyword args for Trainer.
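As a small, self-contained illustration of the documented default for max_epochs, the sketch below mirrors the formula np.min([round((20000 / n_cells) * 400), 400]) using plain Python min/round (the helper name default_max_epochs is hypothetical, not part of the scvi-tools API):

```python
def default_max_epochs(n_cells: int) -> int:
    """Mirror the documented default for max_epochs.

    Uses plain min/round instead of np.min; for positive integer n_cells
    the result is the same.
    """
    return min(round((20000 / n_cells) * 400), 400)


# Small datasets are capped at 400 passes; larger datasets scale the count down.
print(default_max_epochs(10_000))   # -> 400 (capped)
print(default_max_epochs(40_000))   # -> 200
print(default_max_epochs(100_000))  # -> 80
```

This shows why, for datasets above 20,000 cells, the default number of epochs shrinks roughly in proportion to dataset size.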