class scvi.train.AdversarialTrainingPlan(module, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, adversarial_classifier=False, scale_adversarial_loss='auto', **loss_kwargs)[source]

Bases: scvi.train._trainingplans.TrainingPlan

Train VAEs with an adversarial loss option to encourage latent space mixing.

module : BaseModuleClass

A module instance from class BaseModuleClass.


lr : float (default: 0.001)

Learning rate used for Adam optimization.


weight_decay : float (default: 1e-06)

Weight decay used in Adam.

n_steps_kl_warmup : int | None (default: None)

Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

n_epochs_kl_warmup : int | None (default: 400)

Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
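The precedence between the two warmup parameters can be sketched as follows. This is an illustrative reimplementation of the annealing rule described above, not the scvi-tools source:

```python
def kl_weight(epoch, step, n_epochs_kl_warmup=400, n_steps_kl_warmup=None):
    """Linear KL annealing weight in [0, 1].

    Epoch-based warmup takes precedence when set; step-based warmup
    applies only when n_epochs_kl_warmup is None; with neither set,
    the weight stays at 1.0 (no annealing).
    """
    if n_epochs_kl_warmup is not None:
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        return min(1.0, step / n_steps_kl_warmup)
    return 1.0
```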

reduce_lr_on_plateau : bool (default: False)

Whether to monitor validation loss and reduce learning rate when validation set lr_scheduler_metric plateaus.

lr_factor : float (default: 0.6)

Factor by which to reduce the learning rate.

lr_patience : int (default: 30)

Number of epochs with no improvement after which learning rate will be reduced.

lr_threshold : float (default: 0.0)

Threshold for measuring the new optimum.

lr_scheduler_metric : {'elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'} (default: 'elbo_validation')

Which metric to track for learning rate reduction.

lr_min : float (default: 0)

Minimum learning rate allowed.
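When reduce_lr_on_plateau is enabled, lr_factor, lr_patience, lr_threshold, and lr_min behave like their counterparts in torch.optim.lr_scheduler.ReduceLROnPlateau. A minimal, framework-free sketch of that plateau logic (an illustration of how the parameters interact, not the scvi-tools implementation):

```python
class PlateauSketch:
    """Reduce the learning rate when a monitored metric stops improving."""

    def __init__(self, lr=0.001, lr_factor=0.6, lr_patience=30,
                 lr_threshold=0.0, lr_min=0.0):
        self.lr = lr
        self.factor = lr_factor
        self.patience = lr_patience
        self.threshold = lr_threshold
        self.lr_min = lr_min
        self.best = float("inf")   # metric is minimized (e.g. validation ELBO)
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric < self.best - self.threshold:
            # New optimum (beyond the threshold): reset the patience counter.
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
            if self.num_bad_epochs > self.patience:
                # Plateau detected: shrink lr, but never below lr_min.
                self.lr = max(self.lr * self.factor, self.lr_min)
                self.num_bad_epochs = 0
        return self.lr
```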

adversarial_classifier : bool | Classifier (default: False)

Whether to use an adversarial classifier in the latent space.

scale_adversarial_loss : float | {'auto'} (default: 'auto')

Scaling factor on the adversarial components of the loss. By default, the adversarial loss is scaled from 1 to 0 following the opposite of the KL warmup schedule.
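The 'auto' behavior described above can be sketched as follows, assuming the adversarial scale mirrors the KL warmup weight in reverse (a reading of the description, not the library source):

```python
def adversarial_loss_scale(scale_adversarial_loss, kl_weight):
    """Return the multiplier applied to the adversarial loss terms.

    'auto' ramps the adversarial scale from 1 down to 0 as the KL
    weight ramps from 0 up to 1; any float is used as a fixed scale.
    """
    if scale_adversarial_loss == "auto":
        return 1.0 - kl_weight
    return float(scale_adversarial_loss)
```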


**loss_kwargs

Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.




configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization.

loss_adversarial_classifier(z, batch_index)

Loss for the adversarial classifier.


training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps.

training_step(batch, batch_idx[, optimizer_idx])

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
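The training step alternates between two optimizers: one updates the VAE parameters (minimizing the ELBO loss while trying to fool the classifier), the other updates the adversarial classifier on its own loss. A schematic, framework-free sketch of that alternation (names and the exact sign convention here are illustrative, not the scvi-tools internals):

```python
def training_step_sketch(optimizer_idx, vae_loss, adv_loss, kappa):
    """Return the scalar objective for the optimizer being stepped.

    optimizer_idx 0: VAE parameters -- minimize the ELBO-based loss
        while *increasing* the classifier's loss (subtracted term),
        scaled by kappa (see scale_adversarial_loss).
    optimizer_idx 1: classifier parameters -- minimize its own loss
        so it keeps predicting batch labels from the latent z.
    """
    if optimizer_idx == 0:
        return vae_loss - kappa * adv_loss
    return adv_loss
```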