scvi.train.SemiSupervisedTrainingPlan
- class scvi.train.SemiSupervisedTrainingPlan(module, n_classes, *, classification_ratio=50, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', compile=False, compile_kwargs=None, **loss_kwargs)
Bases: TrainingPlan
Lightning module task for semi-supervised training.
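Conceptually, semi-supervised training adds a weighted classification term to the generative model's loss, with `classification_ratio` (see the signature above) as the weight. The sketch below assumes a plain weighted sum of scalar losses and that the classification term is skipped when a minibatch contains no labeled cells. `semi_supervised_loss` is a hypothetical helper, not part of scvi-tools; the module's own loss method may combine the terms differently in detail.

```python
from typing import Optional


def semi_supervised_loss(
    unsupervised_loss: float,
    classification_loss: Optional[float],
    classification_ratio: float = 50.0,
) -> float:
    """Hypothetical helper: unsupervised loss plus a weighted classification loss.

    classification_loss is None when the minibatch carries no labels,
    in which case only the unsupervised term contributes.
    """
    if classification_loss is None:
        return unsupervised_loss
    # Assumed weighted sum: classification_ratio scales the classifier term.
    return unsupervised_loss + classification_ratio * classification_loss


print(semi_supervised_loss(120.0, 0.5))  # 145.0
```

A large default weight such as 50 makes the classifier term matter even though the unsupervised loss is typically summed over many genes.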
- Parameters:
  - module (BaseModuleClass) – A module instance from class BaseModuleClass.
  - n_classes (int) – The number of classes in the labeled dataset.
  - classification_ratio (int (default: 50)) – Weight of the classification loss in the loss function.
  - lr (float (default: 0.001)) – Learning rate used for optimization with Adam.
  - weight_decay (float (default: 1e-06)) – Weight decay used in Adam.
  - n_steps_kl_warmup (int | None (default: None)) – Number of training steps (minibatches) over which to scale the weight on KL divergences from 0 to 1. Only used when n_epochs_kl_warmup is None.
  - n_epochs_kl_warmup (int | None (default: 400)) – Number of epochs over which to scale the weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are set.
  - reduce_lr_on_plateau (bool (default: False)) – Whether to reduce the learning rate when lr_scheduler_metric plateaus on the validation set.
  - lr_factor (float (default: 0.6)) – Factor by which the learning rate is multiplied when it is reduced.
  - lr_patience (int (default: 30)) – Number of epochs with no improvement after which the learning rate is reduced.
  - lr_threshold (float (default: 0.0)) – Threshold for measuring the new optimum.
  - lr_scheduler_metric (Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')) – Which validation metric to track for learning rate reduction.
  - **loss_kwargs – Keyword arguments passed to the loss method of the module. kl_weight should not be passed here; it is handled automatically.
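The warmup and plateau parameters above configure two schedules. The sketch below illustrates both under stated assumptions: the KL weight is assumed to ramp linearly from 0 to 1, with epoch-based warmup taking precedence over step-based warmup, and plateau reduction is assumed to follow PyTorch `ReduceLROnPlateau` semantics in "min" mode with an absolute threshold and no cooldown or minimum learning rate. `kl_weight` and `PlateauLR` are hypothetical names for illustration, not scvi-tools APIs.

```python
from typing import Optional


def kl_weight(epoch: int, step: int,
              n_epochs_kl_warmup: Optional[int] = 400,
              n_steps_kl_warmup: Optional[int] = None) -> float:
    """Hypothetical linear KL warmup from 0 to 1."""
    if n_epochs_kl_warmup is not None:
        # Epoch-based warmup overrides step-based warmup when it is set.
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        return min(1.0, step / n_steps_kl_warmup)
    # No warmup configured: full KL weight from the start.
    return 1.0


class PlateauLR:
    """Hypothetical 'min'-mode plateau scheduler mirroring lr_factor, lr_patience, lr_threshold."""

    def __init__(self, lr: float = 1e-3, lr_factor: float = 0.6,
                 lr_patience: int = 30, lr_threshold: float = 0.0) -> None:
        self.lr = lr
        self.factor = lr_factor
        self.patience = lr_patience
        self.threshold = lr_threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        # An epoch improves only if it beats the best value by more than the threshold.
        if metric < self.best - self.threshold:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce once the count of non-improving epochs exceeds the patience.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauLR(lr=1.0, lr_factor=0.5, lr_patience=2)
print([sched.step(m) for m in [10.0, 9.0, 9.0, 9.0, 9.0]])  # [1.0, 1.0, 1.0, 1.0, 0.5]
```

With the defaults, the KL weight reaches 1 after 400 epochs, and a learning-rate reduction requires more than 30 consecutive epochs without improvement in the tracked validation metric.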
Attributes table
Methods table
| compute_and_log_metrics | Computes and logs metrics. |
| log_with_mode | Log with mode. |
| training_step | Training step for semi-supervised training. |
| validation_step | Validation step for semi-supervised training. |
Attributes
- SemiSupervisedTrainingPlan.training: bool
Methods
- SemiSupervisedTrainingPlan.compute_and_log_metrics(loss_output, metrics, mode)
Computes and logs metrics.
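The lr_scheduler_metric choices ('elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation') suggest that metrics are logged under a `<metric>_<mode>` key. The sketch below assumes that naming convention and uses a hypothetical `MetricLog` class with illustrative method names; the real methods work on the module's loss output and Lightning's logging, which this sketch does not reproduce.

```python
from collections import defaultdict
from typing import Dict, List


class MetricLog:
    """Hypothetical stand-in for a logger that records values per mode-suffixed key."""

    def __init__(self) -> None:
        self.history: Dict[str, List[float]] = defaultdict(list)

    def log_with_mode(self, key: str, value: float, mode: str) -> None:
        # Suffix the metric name with the mode, e.g. "elbo" + "validation" -> "elbo_validation".
        self.history[f"{key}_{mode}"].append(value)

    def compute_and_log_metrics(self, metrics: Dict[str, float], mode: str) -> None:
        # Log every metric under its mode-specific key.
        for key, value in metrics.items():
            self.log_with_mode(key, value, mode)


log = MetricLog()
log.compute_and_log_metrics(
    {"elbo": 310.2, "reconstruction_loss": 290.0, "kl_local": 20.2}, mode="validation"
)
print(sorted(log.history))
# ['elbo_validation', 'kl_local_validation', 'reconstruction_loss_validation']
```

Keying by mode keeps training and validation values separate, which is what lets a scheduler track a validation-only metric such as 'elbo_validation'.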