scvi.train.SemiSupervisedTrainingPlan#

class scvi.train.SemiSupervisedTrainingPlan(module, classification_ratio=50, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', **loss_kwargs)[source]#

Bases: TrainingPlan

Lightning module task for semi-supervised training.

Parameters:
module : BaseModuleClass

A module instance from class BaseModuleClass.

classification_ratio : int (default: 50)

Weight of the classification loss in the loss function.

lr : float (default: 0.001)

Learning rate used for Adam optimization.

weight_decay : float (default: 1e-06)

Weight decay used in Adam.

n_steps_kl_warmup : int | None (default: None)

Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.

n_epochs_kl_warmup : int | None (default: 400)

Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.

reduce_lr_on_plateau : bool (default: False)

Whether to monitor validation loss and reduce learning rate when validation set lr_scheduler_metric plateaus.

lr_factor : float (default: 0.6)

Factor to reduce learning rate.

lr_patience : int (default: 30)

Number of epochs with no improvement after which learning rate will be reduced.

lr_threshold : float (default: 0.0)

Threshold for measuring the new optimum.

lr_scheduler_metric : Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')

Which metric to track for learning rate reduction.

**loss_kwargs

Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
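A minimal construction sketch, assuming `module` is an instance of BaseModuleClass (for example, taken from an already-initialized semi-supervised scvi-tools model); the keyword values simply restate the defaults documented above:

from scvi.train import SemiSupervisedTrainingPlan

# `module` is assumed to be a BaseModuleClass instance, e.g. model.module
# from a semi-supervised scvi-tools model.
plan = SemiSupervisedTrainingPlan(
    module,
    classification_ratio=50,       # weight of the classification loss
    lr=1e-3,                       # Adam learning rate
    weight_decay=1e-6,             # Adam weight decay
    n_epochs_kl_warmup=400,        # scale KL weight from 0 to 1 over 400 epochs
    reduce_lr_on_plateau=True,     # reduce lr when the metric below plateaus
    lr_scheduler_metric="elbo_validation",
)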

Attributes table#

Methods table#

training_step(batch, batch_idx[, optimizer_idx])

Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

validation_step(batch, batch_idx[, ...])

Operates on a single batch of data from the validation set.

Attributes#

CHECKPOINT_HYPER_PARAMS_KEY#

SemiSupervisedTrainingPlan.CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'#

CHECKPOINT_HYPER_PARAMS_NAME#

SemiSupervisedTrainingPlan.CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name'#

CHECKPOINT_HYPER_PARAMS_TYPE#

SemiSupervisedTrainingPlan.CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'#

T_destination#

SemiSupervisedTrainingPlan.T_destination#

alias of TypeVar(‘T_destination’, bound=Mapping[str, Tensor])

automatic_optimization#

SemiSupervisedTrainingPlan.automatic_optimization#

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

Return type:

bool
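For context, the standard Lightning pattern when automatic optimization is turned off looks like the sketch below (generic LightningModule code, not part of this class; compute_loss is a hypothetical helper):

import pytorch_lightning as pl

class ManualOptimizationPlan(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # take over backward/step/zero_grad

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()              # optimizer(s) from configure_optimizers
        loss = self.compute_loss(batch)      # hypothetical loss helper
        opt.zero_grad()
        self.manual_backward(loss)           # use instead of loss.backward()
        opt.step()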

current_epoch#

SemiSupervisedTrainingPlan.current_epoch#

The current epoch in the Trainer.

If no Trainer is attached, this property is 0.

Return type:

int

device#

SemiSupervisedTrainingPlan.device#
Return type:

str | device

dtype#

SemiSupervisedTrainingPlan.dtype#
Return type:

str | dtype

dump_patches#

SemiSupervisedTrainingPlan.dump_patches: bool = False#

This allows better BC support for load_state_dict(). In state_dict(), the version number will be saved in the attribute _metadata of the returned state dict, and thus pickled. _metadata is a dictionary with keys that follow the naming convention of the state dict. See _load_from_state_dict on how to use this information in loading.

If new parameters/buffers are added/removed from a module, this number shall be bumped, and the module’s _load_from_state_dict method can compare the version number and do appropriate changes if the state dict is from before the change.

example_input_array#

SemiSupervisedTrainingPlan.example_input_array#

The example input array is a specification of what the module can consume in the forward() method. The return type is interpreted as follows:

  • Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.example_input_array)

  • Tuple: The input array should be interpreted as a sequence of positional arguments, i.e., model.forward(*model.example_input_array)

  • Dict: The input array represents named keyword arguments, i.e., model.forward(**model.example_input_array)

Return type:

Any
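A short sketch of the three accepted forms inside a LightningModule's __init__ (the shapes and class name are illustrative placeholders):

import torch
import pytorch_lightning as pl

class ExampleInputModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Single tensor -> interpreted as self.forward(example_input_array)
        self.example_input_array = torch.randn(8, 32)
        # Tuple -> interpreted as self.forward(*example_input_array)
        # self.example_input_array = (torch.randn(8, 32), torch.ones(8, 1))
        # Dict -> interpreted as self.forward(**example_input_array)
        # self.example_input_array = {"x": torch.randn(8, 32), "mask": torch.ones(8, 1)}

    def forward(self, x, mask=None):
        return x.sum()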

global_rank#

SemiSupervisedTrainingPlan.global_rank#

The index of the current process across all nodes and devices.

Return type:

int

global_step#

SemiSupervisedTrainingPlan.global_step#

Total training batches seen across all epochs.

If no Trainer is attached, this property is 0.

Return type:

int

hparams#

SemiSupervisedTrainingPlan.hparams#

The collection of hyperparameters saved with save_hyperparameters(). It is mutable by the user. For the frozen set of initial hyperparameters, use hparams_initial.

Returns:

mutable hyperparameters dictionary

Return type:

Union[AttributeDict, dict, Namespace]

hparams_initial#

SemiSupervisedTrainingPlan.hparams_initial#

The collection of hyperparameters saved with save_hyperparameters(). These contents are read-only. Manual updates to the saved hyperparameters can instead be performed through hparams.

Returns:

immutable initial hyperparameters

Return type:

AttributeDict
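A generic LightningModule sketch (not specific to this class) showing how hyperparameters recorded with save_hyperparameters() surface through hparams and hparams_initial:

import pytorch_lightning as pl

class HparamsDemo(pl.LightningModule):
    def __init__(self, lr=1e-3, weight_decay=1e-6):
        super().__init__()
        self.save_hyperparameters()   # records lr and weight_decay

demo = HparamsDemo(lr=5e-4)
print(demo.hparams.lr)                # 0.0005, mutable view
print(demo.hparams_initial)           # frozen copy captured at save time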

kl_weight#

SemiSupervisedTrainingPlan.kl_weight#

Scaling factor on KL divergence during training.
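The schedule is driven by n_steps_kl_warmup / n_epochs_kl_warmup; as a hypothetical illustration (not the library's exact code), a linear warmup can be written as:

from typing import Optional

def linear_kl_weight(progress: int, warmup_span: Optional[int]) -> float:
    # Illustrative linear ramp from 0 to 1; not the scvi-tools implementation.
    if warmup_span is None:
        return 1.0                    # no warmup configured
    return min(1.0, progress / warmup_span)

# With n_epochs_kl_warmup=400, the weight reaches 0.5 at epoch 200 and 1.0 at 400.
assert linear_kl_weight(200, 400) == 0.5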

loaded_optimizer_states_dict#

SemiSupervisedTrainingPlan.loaded_optimizer_states_dict#
Return type:

dict

local_rank#

SemiSupervisedTrainingPlan.local_rank#

The index of the current process within a single node.

Return type:

int

logger#

SemiSupervisedTrainingPlan.logger#

Reference to the logger object in the Trainer.

model_size#

SemiSupervisedTrainingPlan.model_size#

Returns the model size in megabytes (MB).

Note

This property will not return correct value for Deepspeed (stage 3) and fully-sharded training.

Return type:

float

n_obs_training#

SemiSupervisedTrainingPlan.n_obs_training#

Number of observations in the training set.

This will update the loss kwargs for loss rescaling.

Notes

This can get set after initialization.

n_obs_validation#

SemiSupervisedTrainingPlan.n_obs_validation#

Number of observations in the validation set.

This will update the loss kwargs for loss rescaling.

Notes

This can get set after initialization.
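For example, a training runner could assign these once the data split is known (a sketch; plan, train_idx, and val_idx are hypothetical):

# `plan` is a SemiSupervisedTrainingPlan; `train_idx` and `val_idx` are
# hypothetical index arrays produced by the data split.
plan.n_obs_training = len(train_idx)      # updates loss kwargs for rescaling
plan.n_obs_validation = len(val_idx)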

on_gpu#

SemiSupervisedTrainingPlan.on_gpu#

Returns True if this model is currently located on a GPU.

Useful to set flags around the LightningModule for different CPU vs GPU behavior.

truncated_bptt_steps#

SemiSupervisedTrainingPlan.truncated_bptt_steps#

Enables Truncated Backpropagation Through Time in the Trainer when set to a positive integer.

It represents the number of times training_step() gets called before backpropagation. If this is > 0, the training_step() receives an additional argument hiddens and is expected to return a hidden state.

Return type:

int

training#

SemiSupervisedTrainingPlan.training: bool#

Methods#

training_step#

SemiSupervisedTrainingPlan.training_step(batch, batch_idx, optimizer_idx=0)[source]#

Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

Parameters:
batch : Tensor | (Tensor, …) | [Tensor, …]

The output of your DataLoader. A tensor, tuple or list.

batch_idx : int

Integer index of this batch

optimizer_idx : int

When using multiple optimizers, this argument will also be present.

hiddens : Any

Passed in if truncated_bptt_steps > 0.

Returns:

Any of the following:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(batch, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}

Note

The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
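For this class specifically, the important detail is that the returned loss adds a classification term weighted by classification_ratio on top of the usual ELBO terms; a schematic sketch (not the scvi-tools implementation):

import torch

def semi_supervised_loss(elbo_loss: torch.Tensor,
                         classification_loss: torch.Tensor,
                         classification_ratio: float = 50.0) -> torch.Tensor:
    # Schematic: total = ELBO term + classification_ratio * classification loss.
    return elbo_loss + classification_ratio * classification_loss

total = semi_supervised_loss(torch.tensor(1.0), torch.tensor(0.2))  # -> 11.0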

validation_step#

SemiSupervisedTrainingPlan.validation_step(batch, batch_idx, optimizer_idx=0)[source]#

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)

Parameters:
batch : Tensor | (Tensor, …) | [Tensor, …]

The output of your DataLoader. A tensor, tuple or list.

batch_idx : int

The index of this batch

dataloader_idx : int

The index of the dataloader that produced this batch (only if multiple val dataloaders used)

Returns:

  • Any object or value

  • None - Validation will skip to the next batch

# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx):
    ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don’t need to validate you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.