scvi.train.ClassifierTrainingPlan
- class scvi.train.ClassifierTrainingPlan(classifier, *, lr=0.001, weight_decay=1e-06, eps=0.01, optimizer='Adam', data_key='X', labels_key='labels', loss=<class 'torch.nn.modules.loss.CrossEntropyLoss'>)
Bases: LightningModule
Lightning module task to train a simple MLP classifier.
- Parameters:
classifier (BaseModuleClass) – A model instance from Classifier.
lr (float (default: 0.001)) – Learning rate used for optimization.
weight_decay (float (default: 1e-06)) – Weight decay used for optimization.
eps (float (default: 0.01)) – Epsilon term passed to the optimizer for numerical stability.
optimizer (Literal['Adam', 'AdamW'] (default: 'Adam')) – Optimizer to use: "Adam" (torch.optim.Adam) or "AdamW" (torch.optim.AdamW).
data_key (str (default: 'X')) – Key of the classifier input in the minibatch tensor dictionary.
labels_key (str (default: 'labels')) – Key of the classifier labels in the minibatch tensor dictionary.
loss (Callable (default: <class 'torch.nn.modules.loss.CrossEntropyLoss'>)) – PyTorch loss to use.
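The parameters above map onto an ordinary PyTorch training step. The following is a minimal plain-PyTorch sketch, not scvi-tools' implementation. It assumes the classifier returns logits suitable for `CrossEntropyLoss`, uses a small `nn.Sequential` MLP as a hypothetical stand-in for a `Classifier` instance, and stores labels as a made-up column of shape `(batch, 1)`. It also performs by hand the zero-grad, backward, and step calls that Lightning would normally run for you.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for a Classifier instance: a small MLP that outputs logits.
classifier = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))

# Documented defaults of ClassifierTrainingPlan.
lr, weight_decay, eps = 1e-3, 1e-6, 0.01
data_key, labels_key = "X", "labels"
loss_fn = nn.CrossEntropyLoss()  # `loss` defaults to a class, so instantiate it

optimizer = torch.optim.Adam(
    classifier.parameters(), lr=lr, weight_decay=weight_decay, eps=eps
)

# Made-up minibatch: inputs under data_key, labels (shape (batch, 1)) under labels_key.
batch = {
    data_key: torch.randn(16, 10),
    labels_key: torch.randint(0, 3, (16, 1)),
}

def training_step(batch):
    logits = classifier(batch[data_key])
    # Flatten labels to shape (batch,) of integer class indices for CrossEntropyLoss.
    targets = batch[labels_key].view(-1).long()
    loss = loss_fn(logits, targets)
    # Lightning normally performs these three calls automatically.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Repeatedly fit the same minibatch; the loss should trend downward.
losses = [training_step(batch) for _ in range(50)]
print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

Changing `optimizer` to "AdamW" in this sketch would only swap `torch.optim.Adam` for `torch.optim.AdamW`. The batch keys and the loss stay the same.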
Attributes table
training
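Methods table
configure_optimizers() | Configure optimizers for classifier training.
forward(*args, **kwargs) | Passthrough to the module's forward function.
training_step(batch, batch_idx) | Training step for classifier training.
validation_step(batch, batch_idx) | Validation step for classifier training.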
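Validation differs from training mainly in that no gradients are computed and no optimizer step is taken. Here is a minimal plain-PyTorch sketch of what such a step could look like. It assumes a hypothetical MLP stand-in for the classifier that outputs logits, plus the default `data_key` ("X") and `labels_key` ("labels"). It illustrates the idea and is not the library's `validation_step`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny MLP classifier and a made-up validation minibatch.
classifier = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
loss_fn = nn.CrossEntropyLoss()
val_batch = {"X": torch.randn(32, 8), "labels": torch.randint(0, 4, (32, 1))}

@torch.no_grad()
def validation_step(batch, data_key="X", labels_key="labels"):
    # No gradients and no optimizer step: validation only measures the loss.
    logits = classifier(batch[data_key])
    return loss_fn(logits, batch[labels_key].view(-1).long())

classifier.eval()  # switch to evaluation mode before validating
val_loss = validation_step(val_batch)
print(val_loss.requires_grad)  # False
```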
Attributes
- ClassifierTrainingPlan.training: bool
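`training` is the boolean flag inherited from `torch.nn.Module`, which `LightningModule` subclasses. Calling `train()` or `eval()` sets it on a module and on all of its submodules. The sketch below uses a small hypothetical `nn.Sequential` model rather than a `ClassifierTrainingPlan` instance, so it shows only the inherited behavior.

```python
from torch import nn

# Hypothetical model; Dropout is one layer whose behavior depends on `training`.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))

print(model.training)  # True: modules start in training mode
model.eval()
print(model.training)  # False
print(all(not m.training for m in model.modules()))  # True: eval() recurses into submodules
model.train()
print(model.training)  # True
```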
Methods
- ClassifierTrainingPlan.configure_optimizers()
Configure optimizers for classifier training.
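Given the `optimizer`, `lr`, `weight_decay`, and `eps` parameters, configuring the optimizer amounts to choosing `torch.optim.Adam` or `torch.optim.AdamW` and passing the hyperparameters through. The helper `make_optimizer` below is hypothetical and only sketches that selection logic. This sketch also assumes that exactly these three hyperparameters are forwarded and that only parameters with `requires_grad=True` are optimized. Neither assumption is taken from the library source.

```python
import torch
from torch import nn

def make_optimizer(module, optimizer="Adam", lr=1e-3, weight_decay=1e-6, eps=0.01):
    """Hypothetical helper: build Adam or AdamW from the plan's hyperparameters."""
    # Only parameters that require gradients are handed to the optimizer.
    params = [p for p in module.parameters() if p.requires_grad]
    if optimizer == "Adam":
        optim_cls = torch.optim.Adam
    elif optimizer == "AdamW":
        optim_cls = torch.optim.AdamW
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer!r}")
    return optim_cls(params, lr=lr, weight_decay=weight_decay, eps=eps)

model = nn.Linear(4, 2)
opt = make_optimizer(model, optimizer="AdamW")
print(type(opt).__name__, opt.defaults["eps"])  # AdamW 0.01
```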
- ClassifierTrainingPlan.forward(*args, **kwargs)
Passthrough to the module’s forward function.
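A passthrough `forward` lets the training plan be called exactly like the module it wraps. `PassthroughPlan` below is a hypothetical stand-in built on `torch.nn.Module`, which `LightningModule` subclasses. It is not `ClassifierTrainingPlan` itself.

```python
import torch
from torch import nn

class PassthroughPlan(nn.Module):
    """Hypothetical wrapper whose forward simply delegates to the wrapped module."""

    def __init__(self, module):
        super().__init__()
        self.module = module  # registered as a submodule, so its parameters are tracked

    def forward(self, *args, **kwargs):
        # Forward all positional and keyword arguments unchanged.
        return self.module(*args, **kwargs)

torch.manual_seed(0)
inner = nn.Linear(3, 2)
plan = PassthroughPlan(inner)
x = torch.randn(5, 3)
print(plan(x).shape)  # torch.Size([5, 2])
```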