scvi.module.Classifier#

class scvi.module.Classifier(n_input, n_hidden=128, n_labels=5, n_layers=1, dropout_rate=0.1, logits=False, use_batch_norm=True, use_layer_norm=False, activation_fn=<class 'torch.nn.modules.activation.ReLU'>)[source]#

Bases: Module

Basic fully connected neural network classifier.

Parameters:
  • n_input (int) – Number of input dimensions

  • n_hidden (int (default: 128)) – Number of nodes per hidden layer

  • n_labels (int (default: 5)) – Number of output dimensions

  • n_layers (int (default: 1)) – Number of hidden layers

  • dropout_rate (float (default: 0.1)) – Dropout rate applied to nodes

  • logits (bool (default: False)) – Whether to return raw logits instead of softmax-normalized probabilities

  • use_batch_norm (bool (default: True)) – Whether to use batch norm in layers

  • use_layer_norm (bool (default: False)) – Whether to use layer norm in layers

  • activation_fn (Module (default: <class 'torch.nn.modules.activation.ReLU'>)) – Valid activation function from torch.nn
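The logits parameter controls whether the output is a set of raw scores or normalized probabilities. The following is a minimal standard-library sketch of that distinction, assuming the probabilities come from a softmax over the label dimension. The softmax helper and the example scores are illustrative only and are not part of scvi.

```python
import math


def softmax(logits):
    """Map raw scores to non-negative values that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw scores (logits) for three labels.
raw_scores = [2.0, 1.0, 0.1]

# Probabilities preserve the ordering of the logits.
probs = softmax(raw_scores)

# The predicted label is the index of the largest value either way.
predicted = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted)
```

Raw logits are usually the better choice when the output feeds a loss that applies its own normalization, while probabilities are easier to inspect directly.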

Attributes table#

Methods table#

forward(x)

Forward computation.

Attributes#

training

Classifier.training: bool#

Methods#

forward

Classifier.forward(x)[source]#

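Forward computation. Passes the input tensor x through the classifier and returns one output value per label for each row of x.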