class scvi.train.SemiSupervisedTrainingPlan(module, n_classes, *, classification_ratio=50, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', **loss_kwargs)

Bases: TrainingPlan

Lightning module task for semi-supervised training.

Parameters:

  • module (BaseModuleClass) – A module instance from class BaseModuleClass.

  • n_classes (int) – The number of classes in the labeled dataset.

  • classification_ratio (int (default: 50)) – Weight of the classification_loss in the loss function.

  • lr (float (default: 0.001)) – Learning rate used for optimization with Adam.

  • weight_decay (float (default: 1e-06)) – Weight decay used in Adam.

  • n_steps_kl_warmup (Optional[int] (default: None)) – Number of training steps (minibatches) over which to scale the weight on the KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

  • n_epochs_kl_warmup (Optional[int] (default: 400)) – Number of epochs over which to scale the weight on the KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.

  • reduce_lr_on_plateau (bool (default: False)) – Whether to monitor lr_scheduler_metric on the validation set and reduce the learning rate when it plateaus.

  • lr_factor (float (default: 0.6)) – Factor by which the learning rate is reduced.

  • lr_patience (int (default: 30)) – Number of epochs with no improvement after which learning rate will be reduced.

  • lr_threshold (float (default: 0.0)) – Threshold for measuring the new optimum.

  • lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')) – Which metric to track for learning rate reduction.

  • **loss_kwargs – Keyword arguments to pass to the loss method of the module. kl_weight should not be passed here; it is handled automatically.
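The classification_ratio parameter is described only as the weight of the classification loss. The sketch below assumes the simplest reading: the classification term on the labeled cells is multiplied by classification_ratio and added to the unsupervised loss, and the classification loss is a mean softmax cross-entropy over class logits. Both helpers (`cross_entropy`, `semi_supervised_loss`) are hypothetical illustrations in pure Python, not the scvi-tools implementation.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean softmax cross-entropy over a labeled minibatch (hypothetical helper)."""
    total = 0.0
    for row, y in zip(logits, labels):
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_norm - row[y]
    return total / len(labels)


def semi_supervised_loss(
    unsupervised_loss: float,
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    classification_ratio: float = 50.0,
) -> float:
    """Assumed combination: unsupervised loss + classification_ratio * classification loss."""
    if not labels:
        # Minibatch without labeled cells: only the unsupervised term remains.
        return unsupervised_loss
    return unsupervised_loss + classification_ratio * cross_entropy(logits, labels)


# Two classes with equal logits: the cross-entropy is log(2) for the labeled cell.
print(semi_supervised_loss(10.0, [[0.0, 0.0]], [0], classification_ratio=50))
```

With the default ratio of 50, a small classification error can dominate the total loss, which is presumably the point of a large default: it keeps the classifier from being ignored relative to the reconstruction and KL terms.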
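The two KL warmup parameters interact as described above: an epoch-based warmup wins whenever n_epochs_kl_warmup is set, and the step-based warmup applies only when it is None. The sketch below assumes a linear ramp from 0 to 1, capped at 1, which matches the wording "scale weight on KL divergences from 0 to 1". `kl_weight` is a hypothetical helper, not part of the scvi-tools API.

```python
from typing import Optional


def kl_weight(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: Optional[int] = 400,
    n_steps_kl_warmup: Optional[int] = None,
) -> float:
    """Linear KL warmup from 0 to 1 (hypothetical helper)."""
    if n_epochs_kl_warmup:
        # Epoch-based warmup takes precedence whenever it is set
        # (None or 0 is treated as unset, which also avoids dividing by zero).
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup:
        # Step-based warmup is used only when n_epochs_kl_warmup is unset.
        return min(1.0, step / n_steps_kl_warmup)
    # No warmup configured: the KL terms get full weight from the start.
    return 1.0


print(kl_weight(epoch=100, step=0))  # 0.25 (100 / 400 with the default)
print(kl_weight(epoch=0, step=50, n_epochs_kl_warmup=None, n_steps_kl_warmup=200))  # 0.25
print(kl_weight(epoch=500, step=0))  # 1.0 (capped after warmup ends)
```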
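The lr_factor, lr_patience and lr_threshold parameters appear to mirror PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`. The toy class below models that plateau rule under stated assumptions: the tracked metric (e.g. elbo_validation) is minimized, an epoch counts as an improvement only if it beats the best value by more than an absolute threshold, there is no cooldown, and the rate is multiplied by the factor once the metric has failed to improve for more than `patience` consecutive epochs (PyTorch's documented patience semantics). `PlateauLR` is a hypothetical name, not an scvi-tools class.

```python
import math


class PlateauLR:
    """Toy plateau-based LR reduction (hypothetical; min mode, absolute threshold)."""

    def __init__(self, lr=0.001, factor=0.6, patience=30, threshold=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = math.inf
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        if metric < self.best - self.threshold:
            # New optimum: the metric beat the best value by more than `threshold`.
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Reduce after more than `patience` epochs without improvement.
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauLR(patience=2)
for elbo_validation in [100.0, 95.0, 95.0, 95.0, 95.0]:
    lr = sched.step(elbo_validation)
print(lr)  # reduced from 0.001 by the default factor 0.6 on the fifth epoch
```

With lr_threshold=0.0 (the default), any strict decrease of the metric counts as an improvement.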

Methods table

compute_and_log_metrics(loss_output, ...)

Computes and logs metrics.

log_with_mode(key, value, mode, **kwargs)

Log with mode.

training_step(batch, batch_idx)

Training step for semi-supervised training.

validation_step(batch, batch_idx)

Validation step for semi-supervised training.

Attributes

bool


SemiSupervisedTrainingPlan.compute_and_log_metrics(loss_output, metrics, mode)

Computes and logs metrics.

SemiSupervisedTrainingPlan.log_with_mode(key, value, mode, **kwargs)

Log with mode.
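The helper is documented only as "Log with mode." The sketch below rests on an assumption: that it forwards to Lightning's `LightningModule.log` with the mode (such as "train" or "validation") prepended to the metric key, so the same metric is tracked separately per stage. `ToyPlan`, its recording `log` method, and the `accuracy` helper are hypothetical stand-ins used to show how a metric computed from class logits might then be logged.

```python
from typing import Any, Dict, Sequence


class ToyPlan:
    """Stand-in for a training plan; `log` records values instead of calling Lightning."""

    def __init__(self) -> None:
        self.logged: Dict[str, Any] = {}

    def log(self, name: str, value: Any, **kwargs: Any) -> None:
        self.logged[name] = value

    def log_with_mode(self, key: str, value: Any, mode: str, **kwargs: Any) -> None:
        # Assumed key format: mode prefixed to the metric key, e.g. "validation_accuracy".
        self.log(f"{mode}_{key}", value, **kwargs)


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of cells whose argmax logit matches the label (hypothetical helper)."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


plan = ToyPlan()
plan.log_with_mode("accuracy", accuracy([[2.0, 0.1], [0.3, 1.5]], [0, 0]), "validation", on_epoch=True)
print(plan.logged)  # {'validation_accuracy': 0.5}
```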

SemiSupervisedTrainingPlan.training_step(batch, batch_idx)

Training step for semi-supervised training.

SemiSupervisedTrainingPlan.validation_step(batch, batch_idx)

Validation step for semi-supervised training.