TOTALVI.train(max_epochs=400, lr=0.004, use_gpu=None, train_size=0.9, validation_size=None, batch_size=256, early_stopping=True, check_val_every_n_epoch=None, reduce_lr_on_plateau=True, n_steps_kl_warmup=None, n_epochs_kl_warmup=None, adversarial_classifier=None, plan_kwargs=None, **kwargs)

Trains the model using amortized variational inference.

max_epochs : int | None (default: 400)

Number of passes through the dataset.

lr : float (default: 0.004)

Learning rate for optimization.

use_gpu : str | int | bool | None (default: None)

Use default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str, e.g., ‘cuda:0’), or use CPU (if False).
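The four accepted forms of use_gpu can be summarized as a small resolution function. This is a minimal sketch, not the scvi-tools implementation: the helper name resolve_device and the gpu_available flag are hypothetical, and mapping the "default GPU" to PyTorch's "cuda" device string is an assumption.

```python
from typing import Optional, Union


def resolve_device(use_gpu: Optional[Union[str, int, bool]], gpu_available: bool) -> str:
    """Hypothetical helper: map a use_gpu value to a torch-style device string."""
    # bool is a subclass of int in Python, so handle True/False before int.
    if use_gpu is False:
        return "cpu"
    if use_gpu is None or use_gpu is True:
        # Assumption: "default GPU" means PyTorch's current CUDA device.
        return "cuda" if gpu_available else "cpu"
    if isinstance(use_gpu, int):
        return f"cuda:{use_gpu}"  # explicit GPU index
    if isinstance(use_gpu, str):
        return use_gpu  # explicit device name, e.g. "cuda:0"
    raise TypeError(f"unsupported use_gpu value: {use_gpu!r}")


print(resolve_device(None, gpu_available=False))  # falls back to CPU
```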

train_size : float (default: 0.9)

Fraction of cells used for training, in the range [0.0, 1.0].

validation_size : float | None (default: None)

Fraction of cells used for validation. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells form a test set.
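The interaction between train_size and validation_size comes down to simple arithmetic on the number of cells. The sketch below assumes the training count is rounded up and the validation count rounded down; the helper name split_sizes is hypothetical, and the exact rounding used by scvi-tools may differ between versions.

```python
import math
from typing import Optional, Tuple


def split_sizes(
    n_obs: int, train_size: float = 0.9, validation_size: Optional[float] = None
) -> Tuple[int, int, int]:
    """Hypothetical helper: return (n_train, n_val, n_test) cell counts."""
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be in [0.0, 1.0]")
    n_train = math.ceil(train_size * n_obs)  # assumption: round up
    if validation_size is None:
        # Documented default: validation gets everything not used for training.
        n_val = n_obs - n_train
    else:
        if train_size + validation_size > 1.0:
            raise ValueError("train_size + validation_size must not exceed 1")
        n_val = math.floor(validation_size * n_obs)  # assumption: round down
    # Any leftover cells form the test set.
    n_test = n_obs - n_train - n_val
    return n_train, n_val, n_test


print(split_sizes(1000))        # default split, no test set
print(split_sizes(1000, 0.9, 0.05))
```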

batch_size : int (default: 256)

Minibatch size to use during training.
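Because some options (such as n_steps_kl_warmup) are counted in minibatch steps rather than epochs, it can help to convert between the two. This is a sketch under the assumption that the final, partial minibatch of each epoch is kept rather than dropped; the helper names are hypothetical.

```python
import math


def steps_per_epoch(n_train: int, batch_size: int = 256) -> int:
    """Hypothetical helper: minibatches per epoch, keeping the last partial batch."""
    return math.ceil(n_train / batch_size)


def total_steps(n_train: int, batch_size: int = 256, max_epochs: int = 400) -> int:
    """Hypothetical helper: total optimizer steps if training runs all epochs."""
    return steps_per_epoch(n_train, batch_size) * max_epochs


# 900 training cells with the default batch_size give 3 full batches plus 1 partial batch.
print(steps_per_epoch(900), total_steps(900))
```

Early stopping can end training before max_epochs, so total_steps is an upper bound.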

early_stopping : bool (default: True)

Whether to perform early stopping with respect to the validation set.

check_val_every_n_epoch : int | None (default: None)

Run validation every n training epochs. By default, validation is not run unless early_stopping or reduce_lr_on_plateau is True; if either is True, validation runs every epoch.
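The default validation frequency described above can be written as a short decision rule. This sketch follows one reading of the documentation: it assumes an explicitly passed value is kept as-is and that None means "never validate". The helper name resolve_val_frequency is hypothetical; check the behavior of your installed scvi-tools version.

```python
from typing import Optional


def resolve_val_frequency(
    check_val_every_n_epoch: Optional[int],
    early_stopping: bool,
    reduce_lr_on_plateau: bool,
) -> Optional[int]:
    """Hypothetical helper: epochs between validation runs, or None for no validation."""
    if check_val_every_n_epoch is not None:
        return check_val_every_n_epoch  # assumption: explicit values are respected
    if early_stopping or reduce_lr_on_plateau:
        return 1  # both features need a validation metric every epoch
    return None  # default: validation is not run


print(resolve_val_frequency(None, early_stopping=True, reduce_lr_on_plateau=True))
```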

reduce_lr_on_plateau : bool (default: True)

Reduce the learning rate when the validation metric (by default, the ELBO) stops improving.
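The idea behind reducing the learning rate on a plateau can be sketched in pure Python. This minimal scheduler mirrors the general rule used by PyTorch's ReduceLROnPlateau (multiply the rate by a factor once the metric has failed to improve for more than patience epochs) but is not the scheduler scvi-tools configures. The class name and the factor, patience, threshold, and min_lr values are hypothetical, and it assumes a loss-style metric where lower is better.

```python
class PlateauLRScheduler:
    """Hypothetical minimal reduce-on-plateau scheduler (lower metric is better)."""

    def __init__(self, lr: float, factor: float = 0.5, patience: int = 2,
                 threshold: float = 0.0, min_lr: float = 0.0) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.min_lr = min_lr
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        """Record one epoch's validation metric and return the current learning rate."""
        if metric < self.best - self.threshold:
            self.best = metric  # improvement: reset the patience counter
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Plateau detected: shrink the learning rate, but not below min_lr.
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauLRScheduler(lr=0.004)
for val_loss in [10.0, 9.0, 9.0, 9.0, 9.0]:
    current_lr = sched.step(val_loss)
print(current_lr)  # halved after three epochs without improvement
```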

n_steps_kl_warmup : int | None (default: None)

Number of training steps (minibatches) over which the weight on the KL divergence terms is scaled from 0 to 1. Only used when n_epochs_kl_warmup is None. If None, defaults to floor(0.75 * adata.n_obs).

n_epochs_kl_warmup : int | None (default: None)

Number of epochs over which the weight on the KL divergence terms is scaled from 0 to 1. Takes precedence over n_steps_kl_warmup when both are set.
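The precedence between the two KL warmup options, and the step-based default, can be combined into one function. This sketch assumes a linear ramp from 0 to 1 that stays at 1 afterwards; the function names are hypothetical and the exact schedule inside scvi-tools may differ.

```python
import math
from typing import Optional


def default_n_steps_kl_warmup(n_obs: int) -> int:
    """Documented default for n_steps_kl_warmup: floor(0.75 * n_obs)."""
    return math.floor(0.75 * n_obs)


def kl_weight(epoch: int, step: int, n_epochs_kl_warmup: Optional[int],
              n_steps_kl_warmup: Optional[int], n_obs: int) -> float:
    """Hypothetical helper: KL weight at a given epoch/step, assuming a linear ramp."""
    if n_epochs_kl_warmup is not None:
        # Epoch-based warmup takes precedence when set.
        if n_epochs_kl_warmup <= 0:
            return 1.0
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is None:
        n_steps_kl_warmup = default_n_steps_kl_warmup(n_obs)
    if n_steps_kl_warmup <= 0:
        return 1.0
    return min(1.0, step / n_steps_kl_warmup)


# With 1000 cells and neither option set, the ramp lasts 750 steps.
print(kl_weight(epoch=0, step=375, n_epochs_kl_warmup=None,
                n_steps_kl_warmup=None, n_obs=1000))
```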

adversarial_classifier : bool | None (default: None)

Whether to use an adversarial classifier in the latent space. This helps batches mix when some proteins are missing from one or more batches. Defaults to True if missing proteins are detected.
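The None default for adversarial_classifier amounts to "enable it only if some batch lacks protein measurements". The sketch below assumes a protein counts as missing in a batch when every cell in that batch has zero counts for it; that detection rule and both helper names are hypothetical stand-ins for whatever check scvi-tools performs during setup.

```python
from typing import Dict, List, Optional


def has_missing_proteins(protein_counts_by_batch: Dict[str, List[List[float]]]) -> bool:
    """Hypothetical check: True if any protein is all-zero within any batch."""
    for rows in protein_counts_by_batch.values():
        if not rows:
            continue
        for j in range(len(rows[0])):
            if all(row[j] == 0 for row in rows):
                return True
    return False


def resolve_adversarial_classifier(
    adversarial_classifier: Optional[bool],
    protein_counts_by_batch: Dict[str, List[List[float]]],
) -> bool:
    """Hypothetical helper: an explicit bool wins; None falls back to detection."""
    if adversarial_classifier is None:
        return has_missing_proteins(protein_counts_by_batch)
    return adversarial_classifier


# Batch "b" has no counts for protein 0, so the classifier would be enabled.
counts = {"a": [[1, 2], [3, 4]], "b": [[0, 5], [0, 1]]}
print(resolve_adversarial_classifier(None, counts))
```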

plan_kwargs : dict | None (default: None)

Keyword arguments for AdversarialTrainingPlan. Keyword arguments passed to train() overwrite values present in plan_kwargs, when appropriate.
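The override rule for plan_kwargs behaves like a dictionary merge in which explicit train() arguments win. This is a sketch only: it ignores the "when appropriate" qualification (which keys are eligible is not specified here), the helper name merge_plan_kwargs is hypothetical, and weight_decay is used purely as an illustrative key.

```python
from typing import Any, Dict, Optional


def merge_plan_kwargs(plan_kwargs: Optional[Dict[str, Any]], **train_args: Any) -> Dict[str, Any]:
    """Hypothetical helper: combine plan_kwargs with explicit train() arguments."""
    merged = dict(plan_kwargs or {})  # copy so the caller's dict is not mutated
    merged.update(train_args)  # explicit train() arguments take priority
    return merged


user_plan = {"lr": 0.01, "weight_decay": 1e-6}  # hypothetical user-supplied values
print(merge_plan_kwargs(user_plan, lr=0.004))
```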

**kwargs

Other keyword arguments for Trainer.