scvi.train.SemiSupervisedTrainingPlan

class scvi.train.SemiSupervisedTrainingPlan(module, classification_ratio=50, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', **loss_kwargs)

Bases: scvi.train._trainingplans.TrainingPlan

Lightning module task for semi-supervised training.
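A minimal construction sketch, assuming module is an already-built BaseModuleClass instance (for example, the SCANVAE module used by scvi.model.SCANVI); the variable names are illustrative:

from scvi.train import SemiSupervisedTrainingPlan

# Construct the plan with the documented defaults made explicit.
plan = SemiSupervisedTrainingPlan(
    module,                   # assumed: a BaseModuleClass instance
    classification_ratio=50,  # weight of the classification loss
    lr=1e-3,                  # Adam learning rate
    weight_decay=1e-6,        # Adam weight decay
)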

Parameters
module : BaseModuleClass

A module instance from class BaseModuleClass.

classification_ratio : int (default: 50)

Weight of the classification loss in the loss function.

lr : float (default: 0.001)

Learning rate used for Adam optimization.

weight_decay : float (default: 1e-06)

Weight decay used in the Adam optimizer.

n_steps_kl_warmup : int | None (default: None)

Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

n_epochs_kl_warmup : int | None (default: 400)

Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None (see the configuration sketch after this parameter list).

reduce_lr_on_plateau : bool (default: False)

Whether to monitor the lr_scheduler_metric on the validation set and reduce the learning rate when it plateaus, as shown in the sketch after this parameter list.

lr_factor : float (default: 0.6)

Factor by which to reduce the learning rate.

lr_patience : int (default: 30)

Number of epochs with no improvement after which the learning rate will be reduced.

lr_threshold : float (default: 0.0)

Threshold for measuring the new optimum.

lr_scheduler_metric : {'elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'} (default: 'elbo_validation')

Which metric to track for learning rate reduction.

**loss_kwargs

Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
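A hedged configuration sketch covering the KL-warmup interplay and the plateau scheduler described above; it uses only parameters from the signature, and module is assumed to be constructed as in the earlier sketch:

# Epoch-based KL annealing (the default): the KL weight ramps
# from 0 to 1 over the first 400 epochs.
plan = SemiSupervisedTrainingPlan(module, n_epochs_kl_warmup=400)

# Step-based annealing: n_epochs_kl_warmup must be disabled first,
# because it takes precedence when both are not None.
plan = SemiSupervisedTrainingPlan(
    module,
    n_epochs_kl_warmup=None,
    n_steps_kl_warmup=1000,  # ramp over 1000 minibatches instead
)

# Reduce the learning rate when the tracked validation metric plateaus.
plan = SemiSupervisedTrainingPlan(
    module,
    reduce_lr_on_plateau=True,
    lr_factor=0.6,    # new_lr = old_lr * lr_factor
    lr_patience=30,   # epochs without improvement before reducing
    lr_scheduler_metric="elbo_validation",
)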

Methods

training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps.

training_step(batch, batch_idx[, optimizer_idx])

Compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

validation_epoch_end(outputs)

Aggregate validation step information.

validation_step(batch, batch_idx[, …])

Operates on a single batch of data from the validation set.
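These methods are PyTorch Lightning hooks and are normally invoked by a Trainer rather than called directly. A minimal wiring sketch, assuming train_dl and val_dl are dataloaders yielding scvi-tools minibatch dictionaries (in practice, the train() method of scvi-tools models assembles these pieces for you):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=400)
# Calls training_step/validation_step per batch and the *_epoch_end
# hooks at the end of each epoch.
trainer.fit(plan, train_dl, val_dl)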