scvi.module.Classifier
- class scvi.module.Classifier(n_input, n_hidden=128, n_labels=5, n_layers=1, dropout_rate=0.1, logits=False, use_batch_norm=True, use_layer_norm=False, activation_fn=torch.nn.ReLU)
Bases: torch.nn.Module
Basic fully-connected NN classifier.
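To make "fully-connected classifier" concrete, here is a minimal plain-PyTorch sketch of that kind of network, reusing the constructor's parameter names. The helper `make_fc_classifier` is hypothetical and is not part of scvi-tools. It assumes each hidden layer is Linear, then optional batch/layer norm, then the activation, then dropout, followed by a linear output head and an optional softmax. The exact layer ordering and normalization settings inside scvi-tools may differ.

```python
import torch
from torch import nn


def make_fc_classifier(
    n_input: int,
    n_hidden: int = 128,
    n_labels: int = 5,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
    logits: bool = False,
    use_batch_norm: bool = True,
    use_layer_norm: bool = False,
    activation_fn: type = nn.ReLU,
) -> nn.Sequential:
    """Hypothetical approximation of a fully-connected classifier (not scvi-tools code)."""
    layers: list = []
    in_dim = n_input
    for _ in range(n_layers):
        # One hidden block: Linear -> optional norm -> activation -> dropout
        layers.append(nn.Linear(in_dim, n_hidden))
        if use_batch_norm:
            layers.append(nn.BatchNorm1d(n_hidden))
        if use_layer_norm:
            layers.append(nn.LayerNorm(n_hidden))
        layers.append(activation_fn())
        layers.append(nn.Dropout(p=dropout_rate))
        in_dim = n_hidden
    # Output head: one score per label
    layers.append(nn.Linear(in_dim, n_labels))
    if not logits:
        # Convert scores to class probabilities that sum to 1 per row
        layers.append(nn.Softmax(dim=-1))
    return nn.Sequential(*layers)


model = make_fc_classifier(n_input=20, n_labels=3)
model.eval()  # disable dropout; BatchNorm uses its running statistics
probs = model(torch.randn(4, 20))
print(probs.shape)  # torch.Size([4, 3])
```

Setting `logits=True` in this sketch simply leaves off the final softmax, so the last module is the output `nn.Linear`.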
- Parameters:
n_input (int) – Number of input dimensions
n_hidden (int) – Number of nodes in each hidden layer
n_labels (int) – Number of output dimensions (one per label)
n_layers (int) – Number of hidden layers
dropout_rate (float) – Dropout rate applied to the nodes of each hidden layer
logits (bool) – If True, return raw logits; if False, apply a softmax and return class probabilities
use_batch_norm (bool) – Whether to use batch norm in layers
use_layer_norm (bool) – Whether to use layer norm in layers
activation_fn (Module) – Valid activation function from torch.nn
Methods table
| forward | Forward computation. |
Attributes
training
Methods
forward