Classifier

class scvi.models.Classifier(n_input, n_hidden=128, n_labels=5, n_layers=1, dropout_rate=0.1, logits=False)

Bases: torch.nn.modules.module.Module

Basic fully-connected neural network classifier that maps n_input input features to n_labels outputs.
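The class itself is a PyTorch module. The following is a hypothetical, dependency-free sketch of what a classifier with this signature computes. It is not scvi's implementation. It assumes that each of the n_layers hidden layers is an affine map followed by ReLU (with dropout active only in training), that a final affine map produces n_labels scores, and that logits=False means a softmax is applied so that each output row sums to 1. The names SketchClassifier, training and seed are illustrative only.

```python
import math
import random


class SketchClassifier:
    """Hypothetical sketch of a fully-connected classifier (not scvi's code).

    Assumption: hidden layers are affine + ReLU (+ inverted dropout when
    training=True), followed by one affine layer to n_labels outputs, and a
    softmax on top unless logits=True.
    """

    def __init__(self, n_input, n_hidden=128, n_labels=5, n_layers=1,
                 dropout_rate=0.1, logits=False, seed=0):
        self.rng = random.Random(seed)
        self.dropout_rate = dropout_rate
        self.logits = logits
        # Layer widths: input -> n_layers hidden layers -> output scores.
        sizes = [n_input] + [n_hidden] * n_layers + [n_labels]
        # One (weights, biases) pair per affine layer; row i of weights feeds output i.
        self.layers = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            bound = 1.0 / math.sqrt(n_in)
            w = [[self.rng.uniform(-bound, bound) for _ in range(n_in)]
                 for _ in range(n_out)]
            b = [self.rng.uniform(-bound, bound) for _ in range(n_out)]
            self.layers.append((w, b))

    @staticmethod
    def _affine(x, w, b):
        return [sum(wij * xj for wij, xj in zip(row, x)) + bi
                for row, bi in zip(w, b)]

    @staticmethod
    def _softmax(z):
        m = max(z)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in z]
        total = sum(exps)
        return [e / total for e in exps]

    def forward(self, batch, training=False):
        """Return one row of n_labels values per input row in batch."""
        out = []
        for x in batch:
            h = x
            for w, b in self.layers[:-1]:
                h = [max(0.0, v) for v in self._affine(h, w, b)]  # ReLU
                if training and self.dropout_rate > 0:
                    keep = 1.0 - self.dropout_rate
                    # Inverted dropout: zero units at random, rescale survivors.
                    h = [v / keep if self.rng.random() < keep else 0.0 for v in h]
            w, b = self.layers[-1]
            z = self._affine(h, w, b)
            out.append(z if self.logits else self._softmax(z))
        return out
```

For example, SketchClassifier(n_input=10, n_labels=3).forward([[0.5] * 10]) returns one row of three non-negative values summing to 1; with logits=True the same row would hold the raw, unnormalized scores instead.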

Methods Summary

forward(x)

Methods Documentation

forward(x)