class scvi.train.AdversarialTrainingPlan(module, *, optimizer='Adam', optimizer_creator=None, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, adversarial_classifier=False, scale_adversarial_loss='auto', **loss_kwargs)[source]#

Bases: TrainingPlan

Train VAEs with an adversarial loss option to encourage latent space mixing.

Parameters:

  • module (BaseModuleClass) – A module instance from class BaseModuleClass.

  • optimizer (Tunable_[Literal['Adam', 'AdamW', 'Custom']]) – One of “Adam” (Adam), “AdamW” (AdamW), or “Custom”, which requires a custom optimizer creator callable to be passed via optimizer_creator.

  • optimizer_creator (Optional[Callable[[Iterable[Tensor]], Optimizer]]) – A callable taking in parameters and returning an Optimizer. This allows using any PyTorch optimizer with custom hyperparameters.

  • lr (Tunable_[float]) – Learning rate used for optimization, when optimizer_creator is None.

  • weight_decay (Tunable_[float]) – Weight decay used in optimization, when optimizer_creator is None.

  • eps – Epsilon (numerical stability term) used for optimization, when optimizer_creator is None.

  • n_steps_kl_warmup (Tunable_[int]) – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup (Tunable_[int]) – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.

  • reduce_lr_on_plateau (Tunable_[bool]) – Whether to monitor validation loss and reduce learning rate when validation set lr_scheduler_metric plateaus.

  • lr_factor (Tunable_[float]) – Factor to reduce learning rate.

  • lr_patience (Tunable_[int]) – Number of epochs with no improvement after which learning rate will be reduced.

  • lr_threshold (Tunable_[float]) – Threshold for measuring the new optimum.

  • lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']) – Which metric to track for learning rate reduction.

  • lr_min (float) – Minimum learning rate allowed.

  • adversarial_classifier (Union[bool, Classifier]) – Whether to use an adversarial classifier in the latent space.

  • scale_adversarial_loss (Union[float, Literal['auto']]) – Scaling factor on the adversarial components of the loss. By default, the adversarial loss is scaled from 1 to 0 following the opposite of the KL warmup schedule.

  • **loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
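To make the interplay between n_epochs_kl_warmup, n_steps_kl_warmup, and scale_adversarial_loss='auto' concrete, here is a minimal pure-Python sketch of the schedules described above. The function names are illustrative, not scvi-tools internals: the KL weight ramps linearly from 0 to 1 (by epoch when n_epochs_kl_warmup is set, otherwise by step), and with 'auto' the adversarial scale follows the opposite ramp.

```python
def kl_weight(step, epoch, n_steps_kl_warmup=None, n_epochs_kl_warmup=400):
    """Illustrative sketch of the KL warmup schedule: the KL weight
    scales linearly from 0 to 1. Epoch-based warmup takes precedence
    when n_epochs_kl_warmup is not None."""
    if n_epochs_kl_warmup is not None:
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        return min(1.0, step / n_steps_kl_warmup)
    return 1.0


def adversarial_scale(kl_w, scale_adversarial_loss="auto"):
    """With 'auto', the adversarial component is scaled opposite to the
    KL warmup (from 1 down to 0); a float gives a fixed scale."""
    if scale_adversarial_loss == "auto":
        return 1.0 - kl_w
    return scale_adversarial_loss
```

For example, halfway through the default 400-epoch warmup the KL weight is 0.5 and the 'auto' adversarial scale is therefore 0.5 as well.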

Attributes table#

training

Methods table#

configure_optimizers()

Configure optimizers for adversarial training.

loss_adversarial_classifier(z, batch_index)

Loss for adversarial classifier.

training_step(batch, batch_idx[, optimizer_idx])

Training step for adversarial training.


AdversarialTrainingPlan.training: bool#


AdversarialTrainingPlan.configure_optimizers()[source]#

Configure optimizers for adversarial training.


AdversarialTrainingPlan.loss_adversarial_classifier(z, batch_index, predict_true_class=True)[source]#

Loss for adversarial classifier.
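The idea behind this loss can be sketched in pure Python. This is an illustrative simplification, not scvi-tools' implementation: probs holds the classifier's predicted probability over batch labels for each cell's latent z. With predict_true_class=True (the classifier update) we minimize the negative log-probability of the true batch; with False (the VAE update) we instead reward mass away from the true batch, pushing the latent space toward batch mixing.

```python
import math


def loss_adversarial_classifier(probs, batch_index, predict_true_class=True):
    """Sketch of an adversarial batch-classifier loss (illustrative only).

    probs: per-cell predicted probabilities over batch labels.
    batch_index: per-cell true batch label.
    """
    losses = []
    for p, b in zip(probs, batch_index):
        if predict_true_class:
            # classifier phase: standard cross-entropy on the true batch
            losses.append(-math.log(p[b]))
        else:
            # VAE ("fooling") phase: reward probability on other batches
            losses.append(-math.log(max(1.0 - p[b], 1e-8)))
    return sum(losses) / len(losses)
```

A confident, correct classifier thus yields a small loss in the first phase but a large penalty for the VAE in the second, which is what drives the latent representation to hide batch information.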


AdversarialTrainingPlan.training_step(batch, batch_idx, optimizer_idx=0)[source]#

Training step for adversarial training.
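The two-phase structure implied by optimizer_idx can be sketched as follows. This is a schematic, not scvi-tools' code; vae_loss_fn, adv_loss_fn, and adv_scale are hypothetical stand-ins. Lightning invokes the step once per optimizer: index 0 updates the VAE, adding the scaled "fool the classifier" term, while index 1 updates the adversarial classifier on the true batch labels.

```python
def training_step(batch, vae_loss_fn, adv_loss_fn, adv_scale, optimizer_idx=0):
    """Schematic two-optimizer adversarial step (illustrative names)."""
    z, batch_index = batch
    if optimizer_idx == 0:
        # VAE/generator phase: ELBO-style loss plus the scaled
        # adversarial penalty for being predictable by the classifier
        return vae_loss_fn(z) + adv_scale * adv_loss_fn(
            z, batch_index, predict_true_class=False
        )
    # classifier phase: train the classifier to predict the true batch
    return adv_loss_fn(z, batch_index, predict_true_class=True)
```

The alternation is what makes the scheme adversarial: the classifier gets better at reading batch identity out of z, while the VAE is simultaneously penalized whenever it succeeds.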