scvi.train.SemiSupervisedTrainingPlan
- class scvi.train.SemiSupervisedTrainingPlan(module, *, classification_ratio=50, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', **loss_kwargs)
Bases: TrainingPlan
Lightning module task for semi-supervised training.
- Parameters:
module (BaseModuleClass) – A module instance from class BaseModuleClass.
classification_ratio (int) – Weight of the classification_loss in the loss function.
n_steps_kl_warmup (Optional[int]) – Number of training steps (minibatches) over which to scale the weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.
n_epochs_kl_warmup (Optional[int]) – Number of epochs over which to scale the weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
reduce_lr_on_plateau (bool) – Whether to monitor the validation metric given by lr_scheduler_metric and reduce the learning rate when it plateaus.
lr_factor (float) – Factor by which to reduce the learning rate.
lr_patience (int) – Number of epochs with no improvement after which the learning rate will be reduced.
lr_threshold (float) – Threshold for measuring the new optimum.
lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']) – Which metric to track for learning-rate reduction.
**loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
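To make the interaction between n_epochs_kl_warmup and n_steps_kl_warmup concrete, here is a minimal sketch of a KL warmup schedule. It assumes the weight ramps linearly from 0 to 1 and then stays at 1; the function name kl_weight and its signature are hypothetical and are not the scvi-tools API. It only illustrates the precedence rule stated above: the epoch-based schedule wins whenever it is set.

```python
from typing import Optional


def kl_weight(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: Optional[int],
    n_steps_kl_warmup: Optional[int],
) -> float:
    """Hypothetical linear KL warmup from 0 to 1 (illustrative, not scvi-tools code)."""
    # The epoch-based schedule takes precedence whenever it is set
    # (None, or 0, is treated as "not set").
    if n_epochs_kl_warmup:
        return min(1.0, epoch / n_epochs_kl_warmup)
    # Otherwise fall back to the per-minibatch (step-based) schedule.
    if n_steps_kl_warmup:
        return min(1.0, step / n_steps_kl_warmup)
    # No warmup configured: use the full KL weight from the start.
    return 1.0


# With the default n_epochs_kl_warmup=400, the weight is halfway up at epoch 200,
# and n_steps_kl_warmup is ignored because the epoch schedule is set.
print(kl_weight(epoch=200, step=5000, n_epochs_kl_warmup=400, n_steps_kl_warmup=100))
```

A linear ramp is the simplest choice that starts the KL term at zero and lets it grow; whether the actual training plan uses exactly this shape is an assumption here.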
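The reduce_lr_on_plateau, lr_factor, lr_patience and lr_threshold parameters describe a reduce-on-plateau rule. The sketch below is a simplified, hypothetical re-implementation of that idea, loosely modeled on PyTorch's torch.optim.lr_scheduler.ReduceLROnPlateau in "min" mode. The class name PlateauLR is illustrative, and the use of an absolute threshold (improvement means beating the best value by more than lr_threshold) is an assumption. This is not the scvi-tools implementation. Constructor defaults mirror the defaults in the signature above.

```python
class PlateauLR:
    """Hypothetical, simplified reduce-on-plateau rule for a metric to be minimized."""

    def __init__(self, lr: float = 0.001, factor: float = 0.6, patience: int = 30, threshold: float = 0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        # An epoch counts as an improvement only if the metric beats the best
        # value seen so far by more than `threshold` (absolute, assumed here).
        if metric < self.best - self.threshold:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # After more than `patience` epochs without improvement, scale the
        # learning rate by `factor` and start counting again.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


# Feed one value of the tracked metric (e.g. elbo_validation) per epoch.
sched = PlateauLR(lr=0.001, factor=0.6, patience=2)
for elbo_validation in [10.0, 9.0, 9.0, 9.0, 9.0]:
    current_lr = sched.step(elbo_validation)
print(current_lr)  # the fifth epoch exceeds the patience of 2 and triggers one reduction
```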
Attributes table
| training |
Methods table
| training_step | Training step for semi-supervised training. |
| validation_step | Validation step for semi-supervised training. |
Attributes
training
Methods
training_step
- SemiSupervisedTrainingPlan.training_step(batch, batch_idx, optimizer_idx=0)
Training step for semi-supervised training.
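The role of classification_ratio in a training step can be sketched as a weighted sum of two terms. The function below is hypothetical and assumes that the total loss is the unsupervised (ELBO-style) loss plus classification_ratio times the classification_loss, and that minibatches without labelled cells contribute only the unsupervised term. Both points are assumptions for illustration, not a description of the actual training_step internals.

```python
from typing import Optional


def semisupervised_loss(
    unsupervised_loss: float,
    classification_loss: Optional[float],
    classification_ratio: float = 50,
) -> float:
    """Hypothetical combination of the two loss terms (not scvi-tools code)."""
    # Assumed: without labelled cells in the minibatch there is no classifier term.
    if classification_loss is None:
        return unsupervised_loss
    # Otherwise the classifier's loss is added, scaled by classification_ratio.
    return unsupervised_loss + classification_ratio * classification_loss


print(semisupervised_loss(120.0, 0.4))   # labelled cells present
print(semisupervised_loss(120.0, None))  # unlabelled-only minibatch
```

A large default such as 50 keeps the classifier term from being swamped by the reconstruction and KL terms, which are typically summed over many genes.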
validation_step