scvi.module.Classifier#

class scvi.module.Classifier(n_input, n_hidden=128, n_labels=5, n_layers=1, dropout_rate=0.1, logits=False, use_batch_norm=True, use_layer_norm=False, activation_fn=<class 'torch.nn.modules.activation.ReLU'>)[source]#

Bases: Module

Basic fully-connected NN classifier.

Parameters:
  • n_input (int) – Number of input dimensions

  • n_hidden (int) – Number of nodes per hidden layer

  • n_labels (int) – Number of output dimensions

  • n_layers (int) – Number of hidden layers

  • dropout_rate (float) – Dropout rate to apply to nodes

  • logits (bool) – Whether to return raw logits instead of normalized class probabilities

  • use_batch_norm (bool) – Whether to use batch norm in layers

  • use_layer_norm (bool) – Whether to use layer norm in layers

  • activation_fn (Module) – Valid activation function from torch.nn
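
To make the parameters concrete, here is a minimal sketch in plain PyTorch of the kind of network they describe. It is not the scvi-tools implementation: build_classifier_sketch is a hypothetical helper, the order of the linear, normalization, activation, and dropout layers inside each hidden block is an assumption, and so is the use of a softmax when logits=False.

```python
import torch
from torch import nn


def build_classifier_sketch(
    n_input,
    n_hidden=128,
    n_labels=5,
    n_layers=1,
    dropout_rate=0.1,
    logits=False,
    use_batch_norm=True,
    use_layer_norm=False,
    activation_fn=nn.ReLU,
):
    """Hypothetical stand-in that mirrors the documented parameters."""
    layers = []
    width = n_input
    for _ in range(n_layers):
        # One hidden block: linear -> (norm) -> activation -> (dropout).
        layers.append(nn.Linear(width, n_hidden))
        if use_batch_norm:
            layers.append(nn.BatchNorm1d(n_hidden))
        if use_layer_norm:
            layers.append(nn.LayerNorm(n_hidden))
        layers.append(activation_fn())
        if dropout_rate > 0:
            layers.append(nn.Dropout(p=dropout_rate))
        width = n_hidden
    # Output layer: one score per label.
    layers.append(nn.Linear(width, n_labels))
    if not logits:
        # Assumption: with logits=False the scores are normalized by softmax.
        layers.append(nn.Softmax(dim=-1))
    return nn.Sequential(*layers)


torch.manual_seed(0)
net = build_classifier_sketch(n_input=10, n_labels=3)
net.eval()  # disable dropout and use running batch-norm statistics

x = torch.randn(4, 10)  # batch of 4 samples with n_input = 10 features
with torch.no_grad():
    probs = net(x)  # shape (4, 3); each row sums to 1 because logits=False
```

In this sketch, passing logits=True simply omits the final softmax, so the output is a tensor of unnormalized scores with the same shape.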

Attributes table#

Methods table#

forward(x)

Forward computation.

Attributes#

training

Classifier.training: bool#

Methods#

forward

Classifier.forward(x)[source]#

Forward computation.
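
As an illustration of how the output of a forward pass might be consumed, the sketch below uses a single torch.nn.Linear layer as a hypothetical stand-in for a classifier built with logits=True. The input shape (batch, n_input) and output shape (batch, n_labels) are assumptions inferred from the parameter descriptions, not taken from the library source.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for a classifier constructed with logits=True:
# it maps 10 input features to 3 unnormalized class scores.
head = nn.Linear(10, 3)

x = torch.randn(4, 10)          # batch of 4 samples, n_input = 10
y = torch.tensor([0, 2, 1, 2])  # integer class labels in [0, n_labels)

scores = head(x)  # forward pass, shape (4, 3), raw logits

# nn.CrossEntropyLoss expects raw logits, not softmax probabilities.
loss = nn.CrossEntropyLoss()(scores, y)

# Convert logits to probabilities and predicted labels for inspection.
probs = torch.softmax(scores.detach(), dim=-1)
pred = probs.argmax(dim=-1)
```

Requesting logits is the usual choice when the output feeds a loss such as torch.nn.CrossEntropyLoss, which applies log-softmax internally; probabilities are more convenient when the output is reported directly.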