scvi.train.ClassifierTrainingPlan#

class scvi.train.ClassifierTrainingPlan(classifier, *, lr=0.001, weight_decay=1e-06, eps=0.01, optimizer='Adam', data_key='X', labels_key='labels', loss=<class 'torch.nn.modules.loss.CrossEntropyLoss'>)

Bases: LightningModule

Lightning module task to train a simple MLP classifier.

Parameters:
  • classifier (BaseModuleClass) – A model instance from Classifier.

  • lr (float (default: 0.001)) – Learning rate used for optimization.

  • weight_decay (float (default: 1e-06)) – Weight decay used in optimization.

  • eps (float (default: 0.01)) – Epsilon passed to the optimizer; it is added to the denominator of the update for numerical stability.

  • optimizer (Literal['Adam', 'AdamW'] (default: 'Adam')) – Optimizer to use: either 'Adam' (torch.optim.Adam) or 'AdamW' (torch.optim.AdamW).

  • data_key (str (default: 'X')) – Key of the classifier input in the minibatch tensor dictionary.

  • labels_key (str (default: 'labels')) – Key of the classifier labels in the minibatch tensor dictionary.

  • loss (Callable (default: <class 'torch.nn.modules.loss.CrossEntropyLoss'>)) – PyTorch loss used to train the classifier.
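To make the roles of lr, weight_decay, eps, and optimizer concrete, here is a minimal pure-Python sketch of a single update on one scalar parameter. It is an illustration of the textbook Adam and AdamW rules, not the code this class runs: adam_step and its state dictionary are hypothetical names, and the betas of (0.9, 0.999) are assumed from the PyTorch defaults.

```python
import math


def adam_step(param, grad, state, *, lr=0.001, weight_decay=1e-06, eps=0.01,
              optimizer="Adam", betas=(0.9, 0.999)):
    """Apply one Adam or AdamW update to a single scalar parameter (illustrative)."""
    beta1, beta2 = betas
    if optimizer == "Adam":
        # Adam: weight decay is folded into the gradient as an L2 penalty.
        grad = grad + weight_decay * param
    elif optimizer == "AdamW":
        # AdamW: weight decay shrinks the parameter directly (decoupled decay).
        param = param * (1.0 - lr * weight_decay)
    else:
        raise ValueError(f"unsupported optimizer: {optimizer!r}")

    # Running averages of the gradient and the squared gradient.
    state["step"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad

    # Bias correction for the zero-initialised averages.
    m_hat = state["m"] / (1 - beta1 ** state["step"])
    v_hat = state["v"] / (1 - beta2 ** state["step"])

    # eps keeps the denominator away from zero; a larger eps damps the step.
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


state = {"step": 0, "m": 0.0, "v": 0.0}
new_param = adam_step(1.0, 0.5, state, weight_decay=0.0)
```

With the relatively large default eps of 0.01, the first step in this sketch is slightly smaller than lr, whereas a tiny eps such as 1e-8 would give a first step of almost exactly lr.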

Attributes table#

Methods table#

  • configure_optimizers() – Configure optimizers for classifier training.

  • forward(*args, **kwargs) – Passthrough to the module's forward function.

  • training_step(batch, batch_idx) – Training step for classifier training.

  • validation_step(batch, batch_idx) – Validation step for classifier training.

Attributes#

ClassifierTrainingPlan.training: bool#

Methods#

ClassifierTrainingPlan.configure_optimizers()

Configure the optimizer for classifier training from the optimizer, lr, weight_decay, and eps settings given at construction.

ClassifierTrainingPlan.forward(*args, **kwargs)

Passthrough to the module’s forward function.
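As a rough illustration of what a passthrough means here, the sketch below forwards every positional and keyword argument unchanged to the wrapped module. PassthroughPlan and scale are hypothetical stand-ins written for this example, not part of scvi-tools.

```python
class PassthroughPlan:
    """Hypothetical wrapper whose forward simply delegates to the wrapped module."""

    def __init__(self, module):
        self.module = module

    def forward(self, *args, **kwargs):
        # No pre- or post-processing: arguments and return value pass straight through.
        return self.module(*args, **kwargs)


def scale(values, factor=1.0):
    """Toy stand-in for a classifier module: multiply each value by factor."""
    return [v * factor for v in values]


plan = PassthroughPlan(scale)
result = plan.forward([1.0, 2.0], factor=3.0)
```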

ClassifierTrainingPlan.training_step(batch, batch_idx)

Training step for classifier training.
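The sketch below shows, under simplifying assumptions, how data_key and labels_key could select the input and labels from a minibatch dictionary before a cross-entropy loss is computed. It uses plain Python lists instead of tensors; training_step_sketch, cross_entropy, and toy_classifier are hypothetical names, and the actual method may differ in detail.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    peak = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]


def training_step_sketch(classifier, batch, *, data_key="X", labels_key="labels"):
    """Look up inputs and labels by key, classify, and return the mean loss."""
    inputs = batch[data_key]
    labels = batch[labels_key]
    losses = [cross_entropy(classifier(x), y) for x, y in zip(inputs, labels)]
    # Mean over the minibatch, matching the default reduction of CrossEntropyLoss.
    return sum(losses) / len(losses)


def toy_classifier(features):
    """Hypothetical stand-in classifier: use the two features directly as logits."""
    return [features[0], features[1]]


batch = {"X": [[2.0, 0.0], [0.0, 2.0]], "labels": [0, 1]}
loss = training_step_sketch(toy_classifier, batch)
```

Both toy examples are classified correctly with a logit margin of 2, so the mean loss equals log(1 + exp(-2)).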

ClassifierTrainingPlan.validation_step(batch, batch_idx)

Validation step for classifier training.