scvi.model.base.PyroSviTrainMixin
- class scvi.model.base.PyroSviTrainMixin
Mixin class for training Pyro models.
Supports training with minibatches or on the full dataset (in which case the data are copied to the GPU only once).
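The two training modes above differ only in how each epoch is split into batches. The following is a minimal NumPy sketch of that difference. It assumes per-epoch shuffling, and `iter_minibatches` is a hypothetical helper, not part of scvi-tools:

```python
import numpy as np
from typing import Iterator, Optional


def iter_minibatches(
    n_cells: int, batch_size: Optional[int], seed: int = 0
) -> Iterator[np.ndarray]:
    """Yield the cell indices for one epoch (hypothetical helper)."""
    if batch_size is None:
        # Full-data mode: a single "batch" holding every cell, so the data
        # can be moved to the device once and reused every epoch.
        yield np.arange(n_cells)
        return
    # Minibatch mode: shuffle the cells, then cut them into consecutive
    # chunks of batch_size (the last chunk may be smaller).
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_cells)
    for start in range(0, n_cells, batch_size):
        yield perm[start:start + batch_size]


print([len(b) for b in iter_minibatches(300, 128)])   # [128, 128, 44]
print([len(b) for b in iter_minibatches(300, None)])  # [300]
```

In minibatch mode every cell is still visited exactly once per epoch. Only the number of optimisation steps per epoch changes.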
Methods table
| train | Train the model. |
Methods
train
- PyroSviTrainMixin.train(max_epochs=None, use_gpu=None, train_size=0.9, validation_size=None, batch_size=128, early_stopping=False, lr=None, training_plan=<class 'scvi.train._trainingplans.PyroTrainingPlan'>, plan_kwargs=None, **trainer_kwargs)
Train the model.
- Parameters
- max_epochs : int | None (default: None)
  Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).
- use_gpu : str | int | bool | None (default: None)
  Use the default GPU if available (if None or True), the GPU with the given index (if int), the GPU with the given name (if str, e.g., ‘cuda:0’), or the CPU (if False).
- train_size : float (default: 0.9)
  Size of the training set, as a fraction in the range [0.0, 1.0].
- validation_size : float | None (default: None)
  Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.
- batch_size : int (default: 128)
  Minibatch size to use during training. If None, no minibatching occurs and all data is copied to the device (e.g., GPU).
- early_stopping : bool (default: False)
  Perform early stopping. Additional arguments can be passed in **trainer_kwargs. See Trainer for further options.
- lr : float | None (default: None)
  Optimiser learning rate (the default optimiser is ClippedAdam). Specifying an optimiser via plan_kwargs overrides this choice of lr.
- training_plan : PyroTrainingPlan (default: <class 'scvi.train._trainingplans.PyroTrainingPlan'>)
  Training plan class to use; defaults to PyroTrainingPlan.
- plan_kwargs : dict | None (default: None)
  Keyword args for PyroTrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.
- **trainer_kwargs
  Other keyword args for Trainer.
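Several parameters above are resolved from defaults at call time. The following sketch shows one way to compute them: the max_epochs formula, the train/validation/test split, and the precedence between lr and an optimiser given in plan_kwargs. It is a sketch under stated assumptions, not scvi-tools' implementation:

- `default_max_epochs`, `split_sizes` and `resolve_plan_kwargs` are hypothetical names.
- The rounding in `split_sizes` (train count rounded up, validation count rounded down) is an assumption, so the library's splitter may differ by a cell.
- The `optim` / `optim_kwargs` keys are assumed to mirror PyroTrainingPlan's optimiser arguments.

```python
import math
from typing import Any, Dict, Optional, Tuple


def default_max_epochs(n_cells: int) -> int:
    # Formula from the max_epochs description: scale 400 epochs by
    # 20,000 / n_cells and cap the result at 400.
    return int(min(round((20000 / n_cells) * 400), 400))


def split_sizes(
    n_cells: int, train_size: float = 0.9, validation_size: Optional[float] = None
) -> Tuple[int, int, int]:
    """Return (n_train, n_validation, n_test); rounding is an assumption."""
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be in [0.0, 1.0]")
    n_train = math.ceil(train_size * n_cells)
    if validation_size is None:
        # Default: every cell not used for training goes to validation.
        n_val = n_cells - n_train
    else:
        n_val = math.floor(validation_size * n_cells)
    # Whatever is left over forms the test set.
    n_test = n_cells - n_train - n_val
    if n_test < 0:
        raise ValueError("train_size + validation_size must not exceed 1")
    return n_train, n_val, n_test


def resolve_plan_kwargs(
    lr: Optional[float] = None, plan_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    # Copy so the caller's dict is not mutated.
    kwargs = dict(plan_kwargs or {})
    # An optimiser supplied via plan_kwargs takes precedence, so lr is
    # only applied when no "optim" entry is present (assumed key name).
    if lr is not None and "optim" not in kwargs:
        optim_kwargs = dict(kwargs.get("optim_kwargs") or {})
        optim_kwargs["lr"] = lr  # train() argument overwrites plan_kwargs
        kwargs["optim_kwargs"] = optim_kwargs
    return kwargs


print(default_max_epochs(100000))      # 80
print(split_sizes(1000, 0.9, 0.05))    # (900, 50, 50)
```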
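The use_gpu rules in the parameter list can be read as a small resolution function. This is a hypothetical sketch:

- `resolve_device` is not an scvi-tools function.
- The returned strings follow PyTorch device naming.
- `gpu_available` stands in for a runtime check such as `torch.cuda.is_available()`.

```python
from typing import Optional, Union


def resolve_device(use_gpu: Optional[Union[str, int, bool]], gpu_available: bool) -> str:
    """Map a use_gpu value to a PyTorch-style device string (sketch)."""
    # Handle the bools first: in Python, True and False are also ints,
    # so an isinstance(use_gpu, int) test would otherwise catch them.
    if use_gpu is False:
        return "cpu"
    if use_gpu is None or use_gpu is True:
        # Default GPU if one is available, otherwise fall back to the CPU.
        return "cuda" if gpu_available else "cpu"
    if isinstance(use_gpu, int):
        return f"cuda:{use_gpu}"  # GPU selected by index
    if isinstance(use_gpu, str):
        return use_gpu  # GPU selected by name, e.g. "cuda:0"
    raise TypeError(f"unsupported use_gpu value: {use_gpu!r}")


print(resolve_device(None, gpu_available=False))  # cpu
print(resolve_device(1, gpu_available=True))      # cuda:1
```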