scvi.module.Classifier
- class scvi.module.Classifier(n_input, n_hidden=128, n_labels=5, n_layers=1, dropout_rate=0.1, logits=False, use_batch_norm=True, use_layer_norm=False, activation_fn=torch.nn.ReLU)
Bases: torch.nn.modules.module.Module
Basic fully-connected NN classifier.
Parameters
- n_input : int
  Number of input dimensions.
- n_hidden : int (default: 128)
  Number of nodes in each hidden layer.
- n_labels : int (default: 5)
  Number of output dimensions, i.e. the number of class labels.
- n_layers : int (default: 1)
  Number of hidden layers.
- dropout_rate : float (default: 0.1)
  Dropout rate applied to the hidden nodes.
- logits : bool (default: False)
  Whether to return raw logits. If False, class probabilities are returned instead.
- use_batch_norm : bool (default: True)
  Whether to use batch normalization in the hidden layers.
- use_layer_norm : bool (default: False)
  Whether to use layer normalization in the hidden layers.
- activation_fn : Module (default: torch.nn.ReLU)
  A valid activation function class from torch.nn.
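The parameters above describe a stack of fully connected blocks followed by an output layer with one unit per label. The following is a minimal sketch of such a stack in plain PyTorch, so it runs without scvi-tools installed. `SketchClassifier` is a hypothetical name, and both the block order (Linear, then normalization, then activation, then dropout) and the final softmax when `logits=False` are assumptions. This is not the scvi-tools implementation.

```python
import torch
from torch import nn


class SketchClassifier(nn.Module):
    """Hypothetical fully connected classifier mirroring the documented arguments.

    This is an illustrative sketch, not the scvi-tools source.
    """

    def __init__(
        self,
        n_input: int,
        n_hidden: int = 128,
        n_labels: int = 5,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        logits: bool = False,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        activation_fn: type = nn.ReLU,
    ) -> None:
        super().__init__()
        layers = []
        in_dim = n_input
        for _ in range(n_layers):
            # Assumed block order: Linear -> optional norm(s) -> activation -> dropout.
            layers.append(nn.Linear(in_dim, n_hidden))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(n_hidden))
            if use_layer_norm:
                layers.append(nn.LayerNorm(n_hidden))
            layers.append(activation_fn())  # activation_fn is a class, so instantiate it
            layers.append(nn.Dropout(p=dropout_rate))
            in_dim = n_hidden
        # Output layer: one score per label.
        layers.append(nn.Linear(in_dim, n_labels))
        if not logits:
            # Assumption: logits=False returns probabilities via a softmax over labels.
            layers.append(nn.Softmax(dim=-1))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(x)


torch.manual_seed(0)
x = torch.randn(8, 20)                          # 8 samples, 20 input features
clf = SketchClassifier(n_input=20, n_labels=3)
clf.eval()                                      # running BatchNorm stats, dropout disabled
with torch.no_grad():
    probs = clf(x)
print(probs.shape)                              # torch.Size([8, 3])
```

With scvi-tools installed, the documented class accepts the same keyword arguments, for example `Classifier(n_input=20, n_labels=3)`. With `logits=True`, the raw scores pair naturally with `torch.nn.CrossEntropyLoss`, which applies log-softmax internally.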
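Attributes table
| T_destination | alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor]) |
| dump_patches | This allows better BC support for load_state_dict(). |
| training | |
Methods table
| forward(x) | Defines the computation performed at every call. |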
Attributes
T_destination
- Classifier.T_destination
alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
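T_destination is the type variable PyTorch uses in the signature of `Module.state_dict()`. When a caller supplies its own `destination` mapping, that same object is filled and returned, so the return type follows the argument type. The sketch below uses a plain `nn.Linear` instead of Classifier so that it runs with torch alone. PyTorch's documentation discourages passing `destination` in everyday code, so this is only an illustration of the typing contract.

```python
from collections import OrderedDict

import torch
from torch import nn

layer = nn.Linear(4, 2)

# An OrderedDict destination is filled in place and returned as the same object.
dest: OrderedDict[str, torch.Tensor] = OrderedDict()
out = layer.state_dict(destination=dest)
print(out is dest)            # True
print(list(out.keys()))       # ['weight', 'bias']

# A plain dict also works; the return type follows the argument (T_destination).
plain = layer.state_dict(destination={})
print(type(plain).__name__)   # dict
```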
dump_patches
- Classifier.dump_patches: bool = False
This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number is saved in the _metadata attribute of the returned state dict, and is therefore pickled. _metadata is a dictionary whose keys follow the naming convention of the state dict. See _load_from_state_dict for how to use this information when loading.
If parameters or buffers are added to or removed from a module, this number should be bumped, and the module's _load_from_state_dict method can then compare the version number and make the appropriate changes if the state dict predates the change.
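A sketch of that versioning mechanism in plain PyTorch. It assumes what current torch releases do: the version comes from the private class attribute `_version`, and it is recorded per submodule in `state_dict()._metadata`. `VersionedLinear` is a hypothetical layer that bumps `_version` to 2 and overrides the private `_load_from_state_dict` hook to compare the saved version with its own. A real module would rename or convert old keys at that point. This sketch only records the comparison.

```python
from torch import nn


class VersionedLinear(nn.Linear):
    """Hypothetical layer whose checkpoint format is at version 2."""

    _version = 2

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args, **kwargs):
        # local_metadata is this module's entry from state_dict._metadata.
        saved_version = local_metadata.get("version")
        self.loaded_from_version = saved_version
        # An older (or missing) version means the checkpoint predates the change;
        # a real module would migrate keys in state_dict here.
        self.was_upgraded = saved_version is None or saved_version < self._version
        super()._load_from_state_dict(state_dict, prefix, local_metadata, *args, **kwargs)


old = nn.Linear(4, 3)                # uses the nn.Module default _version
sd = old.state_dict()
print(sd._metadata[""])              # metadata recorded for the root module at save time

new = VersionedLinear(4, 3)
new.load_state_dict(sd)
print(new.loaded_from_version, new.was_upgraded)
```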
training
Methods
forward
- Classifier.forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
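The note above can be checked directly. A forward hook fires when the module instance is called, but not when `forward()` is invoked directly. The sketch uses `nn.Linear` to stay self-contained. Classifier is also an `nn.Module`, so the same behavior is expected there. `log_output_shape` is a hypothetical hook written for illustration.

```python
import torch
from torch import nn

calls = []


def log_output_shape(module, inputs, output):
    # Forward-hook signature: hook(module, inputs, output)
    calls.append(tuple(output.shape))


layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(log_output_shape)

x = torch.randn(3, 4)
layer(x)           # calling the instance runs the registered hook
layer.forward(x)   # calling forward() directly skips it

print(calls)       # [(3, 2)]
handle.remove()    # detach the hook when it is no longer needed
```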