scvi.train.SemiSupervisedAdversarialTrainingPlan
- class scvi.train.SemiSupervisedAdversarialTrainingPlan(module, n_classes, *, key_adversarial='batch', optimizer='Adam', optimizer_creator=None, classification_ratio=50, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', lr_min=0, adversarial_classifier=False, scale_adversarial_loss='auto', **loss_kwargs)
Bases: SemiSupervisedTrainingPlan
Lightning module task for semi-supervised training with adversarial loss.
- Parameters:
module (BaseModuleClass) – A module instance from class BaseModuleClass.
optimizer (Literal['Adam', 'AdamW', 'Custom'] (default: 'Adam')) – One of “Adam” (Adam), “AdamW” (AdamW), or “Custom”, which requires a custom optimizer creator callable to be passed via optimizer_creator.
optimizer_creator (Callable[[Iterable[Tensor]], Optimizer] | None (default: None)) – A callable taking in parameters and returning an Optimizer. This allows using any PyTorch optimizer with custom hyperparameters.
n_classes (int) – The number of classes in the labeled dataset.
classification_ratio (int (default: 50)) – Weight of the classification loss in the loss function.
lr (float (default: 0.001)) – Learning rate used for optimization.
weight_decay (float (default: 1e-06)) – Weight decay used for optimization.
eps – eps used for optimization, when optimizer_creator is None.
n_steps_kl_warmup (int | None (default: None)) – Number of training steps (minibatches) over which to scale the weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.
n_epochs_kl_warmup (int | None (default: 400)) – Number of epochs over which to scale the weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
reduce_lr_on_plateau (bool (default: False)) – Whether to monitor the validation metric given by lr_scheduler_metric and reduce the learning rate when it plateaus.
lr_factor (float (default: 0.6)) – Factor by which to reduce the learning rate.
lr_patience (int (default: 30)) – Number of epochs with no improvement after which the learning rate will be reduced.
lr_threshold (float (default: 0.0)) – Threshold for measuring the new optimum.
lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')) – Which metric to track for learning rate reduction.
lr_min (float (default: 0)) – Minimum learning rate allowed.
adversarial_classifier (bool | Classifier (default: False)) – Whether to use an adversarial classifier in the latent space.
scale_adversarial_loss (Union[float, Literal['auto']] (default: 'auto')) – Scaling factor on the adversarial components of the loss. By default, the adversarial loss is scaled from 1 to 0, following the opposite of the KL warmup.
**loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
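The interaction between n_epochs_kl_warmup, n_steps_kl_warmup and scale_adversarial_loss='auto' can be sketched in plain Python. This is an illustrative sketch only, not the library's implementation: the helper names kl_warmup_weight and adversarial_scale are hypothetical, and the linear shape of the ramp is an assumption (the parameter descriptions only say the weight goes from 0 to 1).

```python
def kl_warmup_weight(epoch, step, n_epochs_kl_warmup=400, n_steps_kl_warmup=None):
    """Hypothetical helper: KL weight in [0, 1], assuming a linear ramp."""
    # Epoch-based warmup takes precedence when both values are set.
    if n_epochs_kl_warmup is not None and n_epochs_kl_warmup > 0:
        return min(1.0, epoch / n_epochs_kl_warmup)
    # Step-based warmup is only used when the epoch-based value is None.
    if n_steps_kl_warmup is not None and n_steps_kl_warmup > 0:
        return min(1.0, step / n_steps_kl_warmup)
    # No warmup configured: use the full KL weight from the start.
    return 1.0


def adversarial_scale(kl_weight, scale_adversarial_loss="auto"):
    """Hypothetical helper: scaling applied to the adversarial loss terms."""
    if scale_adversarial_loss == "auto":
        # "Opposite of KL warmup": starts at 1 and decays to 0.
        return 1.0 - kl_weight
    # A fixed float is used as-is.
    return float(scale_adversarial_loss)


if __name__ == "__main__":
    for epoch in (0, 100, 200, 400):
        w = kl_warmup_weight(epoch, step=0)
        print(epoch, w, adversarial_scale(w))
```

Under these assumptions, the adversarial term dominates early in training and fades out as the KL term reaches full weight; passing a float for scale_adversarial_loss keeps the adversarial weight constant instead.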
Attributes table
Methods table
Configure optimizers for adversarial training.
Loss for adversarial classifier.
Update the learning rate via scheduler steps.
Update the learning rate via scheduler steps.
Training step for semi-supervised training.
Attributes
- SemiSupervisedAdversarialTrainingPlan.training: bool
Methods
- SemiSupervisedAdversarialTrainingPlan.configure_optimizers()
Configure optimizers for adversarial training.
- SemiSupervisedAdversarialTrainingPlan.loss_adversarial_classifier(z, batch_index, predict_true_class=True)
Loss for adversarial classifier.
- SemiSupervisedAdversarialTrainingPlan.on_train_epoch_end()
Update the learning rate via scheduler steps.
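To make the plateau-related parameters (lr_factor, lr_patience, lr_threshold, lr_min) concrete, here is a minimal pure-Python sketch of a reduce-on-plateau rule. It is not the scheduler used by the training plan: the class name PlateauLR is hypothetical, and it assumes a lower-is-better metric, an absolute improvement threshold, and a reduction once the number of non-improving epochs exceeds the patience.

```python
class PlateauLR:
    """Hypothetical stand-in for a reduce-on-plateau learning rate rule."""

    def __init__(self, lr=1e-3, lr_factor=0.6, lr_patience=30,
                 lr_threshold=0.0, lr_min=0.0):
        self.lr = lr
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience
        self.lr_threshold = lr_threshold
        self.lr_min = lr_min
        self.best = float("inf")  # best (lowest) metric seen so far
        self.bad_epochs = 0       # consecutive epochs without improvement

    def step(self, metric):
        """Call once per epoch with the monitored validation metric."""
        # Assumption: improvement means beating the best value by more
        # than an absolute threshold.
        if metric < self.best - self.lr_threshold:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce once the patience is exceeded, never going below lr_min.
        if self.bad_epochs > self.lr_patience:
            self.lr = max(self.lr * self.lr_factor, self.lr_min)
            self.bad_epochs = 0
        return self.lr


if __name__ == "__main__":
    sched = PlateauLR(lr=1.0, lr_factor=0.5, lr_patience=2, lr_min=0.3)
    for metric in [10.0] * 7:  # a metric that never improves after epoch 0
        print(sched.step(metric))
```

With reduce_lr_on_plateau=False (the default), no such reduction would be expected to take place and the learning rate stays at lr.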