scvi.module.Classifier

class scvi.module.Classifier(n_input, n_hidden=128, n_labels=5, n_layers=1, dropout_rate=0.1, logits=False, use_batch_norm=True, use_layer_norm=False, activation_fn=torch.nn.ReLU)

Bases: torch.nn.modules.module.Module

Basic fully-connected NN classifier.

Parameters
n_input : int

Number of input dimensions

n_hidden : int (default: 128)

Number of nodes in each hidden layer

n_labels : int (default: 5)

Number of output dimensions (one per label)

n_layers : int (default: 1)

Number of hidden layers

dropout_rate : float (default: 0.1)

Dropout rate applied to the hidden nodes

logits : bool (default: False)

If True, return raw logits; otherwise, return softmax class probabilities

use_batch_norm : bool (default: True)

Whether to use batch norm in layers

use_layer_norm : bool (default: False)

Whether to use layer norm in layers

activation_fn : Module (default: torch.nn.ReLU)

Activation function class from torch.nn, such as torch.nn.ReLU
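For orientation, the parameters above can be read as describing a stack of fully connected hidden layers followed by a linear output layer. The sketch below is a simplified, hypothetical re-implementation in plain PyTorch (`SketchClassifier` is an illustrative name, not part of scvi-tools). It assumes each hidden block is linear, then optional batch/layer norm, then activation, then dropout, and that `logits=False` appends a softmax. The real class builds its hidden layers with scvi-tools internals, so details such as normalization settings may differ.

```python
import torch
from torch import nn


class SketchClassifier(nn.Module):
    """Hypothetical, simplified stand-in for scvi.module.Classifier (not the scvi-tools source)."""

    def __init__(
        self,
        n_input: int,
        n_hidden: int = 128,
        n_labels: int = 5,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        logits: bool = False,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        activation_fn: type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        layers: list[nn.Module] = []
        in_dim = n_input
        for _ in range(n_layers):
            # Assumed hidden block: linear -> optional norm(s) -> activation -> dropout
            layers.append(nn.Linear(in_dim, n_hidden))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(n_hidden))
            if use_layer_norm:
                layers.append(nn.LayerNorm(n_hidden))
            layers.append(activation_fn())
            if dropout_rate > 0:
                layers.append(nn.Dropout(p=dropout_rate))
            in_dim = n_hidden
        # Output layer: one score per label
        layers.append(nn.Linear(in_dim, n_labels))
        if not logits:
            # Turn scores into class probabilities that sum to 1 per row
            layers.append(nn.Softmax(dim=-1))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


torch.manual_seed(0)
model = SketchClassifier(n_input=10, n_labels=3)
model.eval()  # disable dropout and use batch-norm running statistics
probs = model(torch.randn(4, 10))
print(probs.shape)  # torch.Size([4, 3])
```

With `logits=True` the output is unnormalized scores, which is the form `torch.nn.CrossEntropyLoss` expects, since that loss applies log-softmax internally.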


Methods table

forward(x)

Defines the computation performed at every call.

Attributes

T_destination

Classifier.T_destination

alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])

dump_patches

Classifier.dump_patches: bool = False

This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number is saved in the _metadata attribute of the returned state dict and is therefore pickled along with it. _metadata is a dictionary whose keys follow the state-dict naming convention. See _load_from_state_dict for how this information is used during loading.

If parameters or buffers are added to or removed from a module, the version number should be bumped; the module's _load_from_state_dict method can then compare version numbers and apply the appropriate changes when the state dict predates the change.
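The version mechanism described above can be sketched with a hypothetical module, `VersionedScale`, whose only parameter was renamed from `scale` to `weight` when its `_version` was bumped to 2. Its `_load_from_state_dict` override (a private PyTorch hook, so treat this as a sketch rather than a stable API contract) renames the legacy key when the checkpoint's recorded version is missing or older than 2, then defers to the default loading logic.

```python
from collections import OrderedDict

import torch
from torch import nn


class VersionedScale(nn.Module):
    """Hypothetical module: parameter 'scale' was renamed to 'weight' in version 2."""

    _version = 2  # bumped because a parameter was renamed

    def __init__(self, n: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(n))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        old_key = prefix + "scale"
        if (version is None or version < 2) and old_key in state_dict:
            # Checkpoint predates the rename: move the legacy key to its new name
            state_dict[prefix + "weight"] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


model = VersionedScale(3)
# state_dict() records the module's _version under the root prefix ""
print(model.state_dict()._metadata[""]["version"])  # 2

# Simulate a version-1 checkpoint that still uses the old key name
old_sd = OrderedDict(scale=torch.tensor([2.0, 3.0, 4.0]))
old_sd._metadata = OrderedDict({"": {"version": 1}})
model.load_state_dict(old_sd)
print(model.weight.detach().tolist())  # [2.0, 3.0, 4.0]
```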

training

Classifier.training: bool
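`training` is the flag toggled by `train()` and `eval()`, which `Classifier` inherits from `torch.nn.Module`. Because this classifier applies dropout and, by default, batch normalization, switching to evaluation mode before prediction makes its output deterministic. The sketch uses a small stand-in `nn.Sequential` with the same kinds of layers rather than a `Classifier` instance, so it runs without scvi-tools installed.

```python
import torch
from torch import nn

# Stand-in with the layer types Classifier uses: linear, batch norm, activation, dropout
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(p=0.5))
print(model.training)  # True: modules start in training mode

model.eval()  # recursively sets training=False on every submodule
print(model.training, model[3].training)  # False False

x = torch.randn(6, 4)
with torch.no_grad():
    # In eval mode dropout is disabled and batch norm uses running statistics,
    # so repeated calls on the same input give identical outputs
    same = torch.equal(model(x), model(x))
print(same)  # True
```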

Methods

forward

Classifier.forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, you should call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently ignores them.
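The difference between calling the instance and calling forward() directly can be seen with a forward hook. To keep the sketch self-contained it uses a plain `nn.Linear` as a stand-in; a `Classifier` instance should behave the same way, since it inherits this calling machinery from `torch.nn.Module`.

```python
import torch
from torch import nn

calls = []


def record_call(module, inputs, output):
    # Forward hook: runs after the module's forward() when invoked via __call__
    calls.append(tuple(output.shape))


layer = nn.Linear(3, 2)
handle = layer.register_forward_hook(record_call)

x = torch.randn(5, 3)
layer(x)          # goes through __call__, so the hook fires
layer.forward(x)  # bypasses __call__, so the hook is skipped
print(calls)  # [(5, 2)]

handle.remove()  # detach the hook when it is no longer needed
```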