scvi.train.SemiSupervisedTrainingPlan#
- class scvi.train.SemiSupervisedTrainingPlan(module, classification_ratio=50, lr=0.001, weight_decay=1e-06, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', **loss_kwargs)[source]#
Bases: scvi.train._trainingplans.TrainingPlan
Lightning module task for semi-supervised training.
- Parameters
- module : BaseModuleClass
A module instance from class BaseModuleClass.
- classification_ratio : int (default: 50)
Weight of the classification loss in the loss function.
- lr
Learning rate used for optimization (Adam).
- weight_decay
Weight decay used in Adam.
- n_steps_kl_warmup : int | None (default: None)
Number of training steps (minibatches) to scale the weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None.
- n_epochs_kl_warmup : int | None (default: 400)
Number of epochs to scale the weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None.
- reduce_lr_on_plateau : bool (default: False)
Whether to monitor validation loss and reduce the learning rate when the validation-set lr_scheduler_metric plateaus.
- lr_factor : float (default: 0.6)
Factor by which to reduce the learning rate.
- lr_patience : int (default: 30)
Number of epochs with no improvement after which the learning rate will be reduced.
- lr_threshold : float (default: 0.0)
Threshold for measuring the new optimum.
- lr_scheduler_metric : Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation'] (default: 'elbo_validation')
Which metric to track for learning rate reduction.
- **loss_kwargs
Keyword args to pass to the loss method of the module. kl_weight should not be passed here; it is handled automatically.
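The plan wraps a BaseModuleClass instance and is driven by a PyTorch Lightning Trainer. Below is a minimal sketch, assuming module is an existing BaseModuleClass instance and that train_dl and val_dl dataloaders have already been built; in practice, scvi-tools models construct the plan and dataloaders internally inside their train() method, so this is illustrative rather than the library's exact setup.

import pytorch_lightning as pl

from scvi.train import SemiSupervisedTrainingPlan

# `module`, `train_dl`, and `val_dl` are assumed to exist already.
plan = SemiSupervisedTrainingPlan(
    module,                     # a BaseModuleClass instance
    classification_ratio=50,    # weight of the classification loss
    lr=1e-3,
    weight_decay=1e-6,
    n_epochs_kl_warmup=400,     # linear KL warmup over the first 400 epochs
    reduce_lr_on_plateau=True,  # decay lr when 'elbo_validation' plateaus
    lr_factor=0.6,
    lr_patience=30,
)

trainer = pl.Trainer(max_epochs=400)
trainer.fit(plan, train_dl, val_dl)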
Attributes table#
Methods table#
| training_step(batch, batch_idx[, optimizer_idx]) | Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger. |
| validation_step(batch, batch_idx[, optimizer_idx]) | Operates on a single batch of data from the validation set. |
Attributes#
CHECKPOINT_HYPER_PARAMS_KEY#
- SemiSupervisedTrainingPlan.CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'#
CHECKPOINT_HYPER_PARAMS_NAME#
- SemiSupervisedTrainingPlan.CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name'#
CHECKPOINT_HYPER_PARAMS_TYPE#
- SemiSupervisedTrainingPlan.CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'#
T_destination#
- SemiSupervisedTrainingPlan.T_destination#
alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
automatic_optimization#
current_epoch#
device#
dtype#
dump_patches#
- SemiSupervisedTrainingPlan.dump_patches: bool = False#
This allows better BC support for load_state_dict(). In state_dict(), the version number will be saved as in the attribute _metadata of the returned state dict, and thus pickled. _metadata is a dictionary with keys that follow the naming convention of state dict. See _load_from_state_dict on how to use this information in loading.
If new parameters/buffers are added/removed from a module, this number shall be bumped, and the module's _load_from_state_dict method can compare the version number and do appropriate changes if the state dict is from before the change.
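As a hedged illustration of this versioning mechanism, the sketch below bumps a module's _version and migrates an old key inside _load_from_state_dict. The renamed key ("old_weight") and the migration itself are hypothetical; only the hook signature and the use of local_metadata follow standard PyTorch conventions.

import torch
from torch import nn

class MyModule(nn.Module):
    # Bump this when parameters/buffers change between releases.
    _version = 2

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(10))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys, error_msgs):
        # local_metadata carries the _version recorded by state_dict().
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Hypothetical migration: rename a key saved by version 1.
            old_key = prefix + "old_weight"
            if old_key in state_dict:
                state_dict[prefix + "weight"] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)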
example_input_array#
- SemiSupervisedTrainingPlan.example_input_array#
The example input array is a specification of what the module can consume in the forward() method. The return type is interpreted as follows:
- Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.example_input_array)
- Tuple: The input array should be interpreted as a sequence of positional arguments, i.e., model.forward(*model.example_input_array)
- Dict: The input array represents named keyword arguments, i.e., model.forward(**model.example_input_array)
- Return type
Any
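A minimal sketch of setting example_input_array in a LightningModule constructor; the toy module, layer, and sizes are made up for illustration.

import torch
import pytorch_lightning as pl

class ToyPlan(pl.LightningModule):
    def __init__(self, n_input=128):
        super().__init__()
        self.encoder = torch.nn.Linear(n_input, 10)
        # Single tensor: interpreted as self.forward(self.example_input_array).
        # A tuple would mean forward(*example_input_array);
        # a dict would mean forward(**example_input_array).
        self.example_input_array = torch.randn(8, n_input)

    def forward(self, x):
        return self.encoder(x)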
global_rank#
global_step#
hparams#
- SemiSupervisedTrainingPlan.hparams#
The collection of hyperparameters saved with save_hyperparameters(). It is mutable by the user. For the frozen set of initial hyperparameters, use hparams_initial.
- Returns
mutable hyperparameters dictionary
- Return type
Union[AttributeDict, dict, Namespace]
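A small sketch of how hparams gets populated, assuming the standard save_hyperparameters() pattern in a LightningModule; the ToyPlan class is illustrative only.

import pytorch_lightning as pl

class ToyPlan(pl.LightningModule):
    def __init__(self, lr=1e-3, weight_decay=1e-6):
        super().__init__()
        # Records the constructor arguments on self.hparams (and in checkpoints).
        self.save_hyperparameters()

plan = ToyPlan(lr=5e-4)
print(plan.hparams.lr)          # 0.0005, mutable
print(plan.hparams_initial.lr)  # frozen initial value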
hparams_initial#
kl_weight#
- SemiSupervisedTrainingPlan.kl_weight#
Scaling factor on KL divergence during training.
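The weight is annealed linearly from 0 to 1 over the span configured by n_epochs_kl_warmup or n_steps_kl_warmup. A minimal sketch of such a schedule follows; it is illustrative, not the library's exact code.

def compute_kl_weight(current_epoch, global_step,
                      n_epochs_kl_warmup=400, n_steps_kl_warmup=None):
    """Linearly anneal the KL weight from 0 to 1 (illustrative sketch)."""
    if n_epochs_kl_warmup is not None:
        return min(1.0, current_epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup is not None:
        return min(1.0, global_step / n_steps_kl_warmup)
    return 1.0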
loaded_optimizer_states_dict#
local_rank#
logger#
- SemiSupervisedTrainingPlan.logger#
Reference to the logger object in the Trainer.
model_size#
n_obs_training#
- SemiSupervisedTrainingPlan.n_obs_training#
Number of observations in the training set.
This will update the loss kwargs for loss rescaling.
Notes
This can get set after initialization.
n_obs_validation#
- SemiSupervisedTrainingPlan.n_obs_validation#
Number of observations in the validation set.
This will update the loss kwargs for loss rescaling.
Notes
This can get set after initialization.
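A hypothetical usage sketch, assuming a training plan plan, an AnnData object adata, and a 90/10 train/validation split; setting these attributes lets the module rescale its loss to the dataset size.

# `plan` and `adata` are assumed to exist already.
n_train = int(0.9 * adata.n_obs)
plan.n_obs_training = n_train
plan.n_obs_validation = adata.n_obs - n_train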
on_gpu#
- SemiSupervisedTrainingPlan.on_gpu#
Returns True if this model is currently located on a GPU. Useful to set flags around the LightningModule for different CPU vs GPU behavior.
truncated_bptt_steps#
- SemiSupervisedTrainingPlan.truncated_bptt_steps#
Enables Truncated Backpropagation Through Time in the Trainer when set to a positive integer.
It represents the number of times training_step() gets called before backpropagation. If this is > 0, the training_step() receives an additional argument hiddens and is expected to return a hidden state.
- Return type
int
training#
Methods#
training_step#
- SemiSupervisedTrainingPlan.training_step(batch, batch_idx, optimizer_idx=0)[source]#
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters
- batch : Tensor | (Tensor, ...) | [Tensor, ...]
The output of your DataLoader. A tensor, tuple or list.
- batch_idx : int
Integer displaying index of this batch
- optimizer_idx : int
When using multiple optimizers, this argument will also be present.
- hiddens : Any
Passed in if truncated_bptt_steps > 0.
- Returns
Any of:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'
None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(data, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
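For this plan specifically, the training loss combines the module's ELBO-style loss with a classification loss weighted by classification_ratio. The sketch below is schematic only: the batch structure, self.loss_kwargs, self.classification_ratio, and the classification_loss helper are assumptions for illustration, not the library's exact implementation.

def training_step(self, batch, batch_idx, optimizer_idx=0):
    # Schematic: assume the dataloader yields a
    # (full_dataset_batch, labelled_subset_batch) pair.
    full_batch, labelled_batch = batch

    # Unsupervised part: ELBO-style loss on all cells.
    _, _, loss_output = self.forward(full_batch, loss_kwargs=self.loss_kwargs)
    loss = loss_output.loss

    # Supervised part: classification loss on the labelled subset,
    # weighted by classification_ratio (hypothetical helper and attribute).
    classification_loss = self.module.classification_loss(labelled_batch)
    loss = loss + self.classification_ratio * classification_loss

    self.log("train_loss", loss, on_epoch=True)
    return loss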
validation_step#
- SemiSupervisedTrainingPlan.validation_step(batch, batch_idx, optimizer_idx=0)[source]#
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
- batch : Tensor | (Tensor, ...) | [Tensor, ...]
The output of your DataLoader. A tensor, tuple or list.
- batch_idx : int
Integer displaying index of this batch
- optimizer_idx : int
- Returns
Any object or value
None - Validation will skip to the next batch
# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)
# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx):
    ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx):
    # dataloader_idx tells you which dataset this is.
    ...
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
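To tie this back to the scheduler settings above, here is a hedged sketch of a validation step that logs the 'elbo_validation' metric monitored when reduce_lr_on_plateau=True; the forward/loss names are illustrative, not the library's exact code.

def validation_step(self, batch, batch_idx):
    # Lightning has already switched the model to eval mode and disabled grads.
    _, _, loss_output = self.forward(batch, loss_kwargs=self.loss_kwargs)

    # Log the metric that the ReduceLROnPlateau scheduler can monitor
    # (lr_scheduler_metric='elbo_validation' by default).
    self.log("elbo_validation", loss_output.loss, on_epoch=True)
    return loss_output.loss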