scvi.module.Classifier

class scvi.module.Classifier(n_input, n_hidden=128, n_labels=5, n_layers=1, dropout_rate=0.1, logits=False, use_batch_norm=True, use_layer_norm=False, activation_fn=<class 'torch.nn.modules.activation.ReLU'>)

Bases: torch.nn.modules.module.Module

Basic fully-connected NN classifier.

Parameters
n_input : int

Number of input dimensions

n_hidden : int (default: 128)

Number of nodes in each hidden layer

n_labels : int (default: 5)

Number of output dimensions, i.e. the number of class labels

n_layers : int (default: 1)

Number of hidden layers

dropout_rate : float (default: 0.1)

Dropout rate applied to the hidden nodes

logits : bool (default: False)

Whether to return raw logits. If False, a softmax is applied and the output is a vector of class probabilities

use_batch_norm : bool (default: True)

Whether to use batch norm in layers

use_layer_norm : bool (default: False)

Whether to use layer norm in layers

activation_fn : Module (default: <class 'torch.nn.modules.activation.ReLU'>)

Valid activation function from torch.nn
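To show how these parameters plausibly fit together, here is a minimal pure-PyTorch sketch. It is not scvi-tools' implementation. The class name SketchClassifier is hypothetical, and the per-layer ordering (Linear, optional BatchNorm1d, optional LayerNorm, activation, Dropout) is an assumption. What the sketch shares with the documented behaviour is the contract: n_layers hidden blocks of width n_hidden, a linear head with n_labels outputs, and a softmax unless logits=True.

```python
import torch
from torch import nn


class SketchClassifier(nn.Module):
    """Hypothetical stand-in for scvi.module.Classifier (not scvi's code).

    Assumed layout: n_layers blocks of
    Linear -> [BatchNorm1d] -> [LayerNorm] -> activation -> Dropout,
    then a Linear head to n_labels, plus a softmax unless logits=True.
    """

    def __init__(self, n_input, n_hidden=128, n_labels=5, n_layers=1,
                 dropout_rate=0.1, logits=False, use_batch_norm=True,
                 use_layer_norm=False, activation_fn=nn.ReLU):
        super().__init__()
        layers = []
        in_dim = n_input
        for _ in range(n_layers):
            layers.append(nn.Linear(in_dim, n_hidden))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(n_hidden))
            if use_layer_norm:
                layers.append(nn.LayerNorm(n_hidden))
            layers.append(activation_fn())  # activation_fn is a class, e.g. nn.ReLU
            layers.append(nn.Dropout(p=dropout_rate))
            in_dim = n_hidden
        # Output head: one unnormalized score per label.
        layers.append(nn.Linear(in_dim, n_labels))
        if not logits:
            # Turn scores into class probabilities that sum to 1 per row.
            layers.append(nn.Softmax(dim=-1))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        return self.classifier(x)


torch.manual_seed(0)
model = SketchClassifier(n_input=20, n_labels=3)
model.eval()  # disable dropout; BatchNorm uses its running statistics
x = torch.randn(8, 20)
with torch.no_grad():
    probs = model(x)
print(probs.shape)  # torch.Size([8, 3])
```

Building the network as a single nn.Sequential keeps forward trivial. It also makes the logits flag a matter of whether one final Softmax module is appended.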

Methods

forward(x)

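Computes the classifier output for an input x of shape (batch_size, n_input). The result has shape (batch_size, n_labels) and holds class probabilities, or raw logits if logits=True.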
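The difference between the two output modes of forward can be illustrated with plain PyTorch. In this sketch a random tensor stands in for a forward(x) result with logits=True (an assumption, since no real model is built). Applying a softmax over the label dimension gives what logits=False would return. The predicted label is the same in both cases, and torch.nn.functional.cross_entropy expects the raw logits.

```python
import torch
import torch.nn.functional as F

# Hypothetical forward(x) output for 4 observations and 5 labels with logits=True.
torch.manual_seed(0)
logits = torch.randn(4, 5)

# logits=False corresponds to a softmax over the label dimension.
probs = torch.softmax(logits, dim=-1)

# Softmax is monotonic within each row, so the predicted label does not change.
pred_from_logits = logits.argmax(dim=-1)
pred_from_probs = probs.argmax(dim=-1)

# cross_entropy applies log-softmax internally, so it takes raw logits.
targets = torch.tensor([0, 2, 1, 4])
loss = F.cross_entropy(logits, targets)
print(torch.equal(pred_from_logits, pred_from_probs))  # True
```

In practice, logits=True is the natural choice when the output feeds a cross-entropy loss. logits=False is convenient when calibrated per-label probabilities are needed directly.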