Constructing a high-level model

Note

Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.

!pip install --quiet scvi-colab
from scvi_colab import install

install()

import os
import tempfile
from collections.abc import Sequence
from typing import Optional

import numpy as np
import scvi
import torch
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
    NumericalObsField,
)
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
from scvi.module import VAE
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
Last run with scvi-tools version: 1.1.0

Note

You can modify save_dir below to change where the data files for this tutorial are saved.

torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

At this point we have covered:

  1. Data registration via setup_anndata and dataloaders via AnnDataLoader

  2. Building a probabilistic model by subclassing BaseModuleClass

In this tutorial, we will cover the highest-level classes in scvi-tools: the model classes. The main purpose of these classes (e.g., scvi.model.SCVI) is to wrap module instantiation, training, and subsequent posterior queries of our module into a convenient interface. These model classes are the fundamental objects driving scientific analysis of data with scvi-tools. By convention, we will refer to these objects as “models” and to the lower-level objects presented in the previous tutorial as “modules”.

A simple model class

Here we will walk through an example of building the scvi.model.SCVI class, progressively adding functionality to it.

Sketch of BaseModelClass

Let us start with a high-level overview of the BaseModelClass that we will inherit from. Note that this is pseudocode meant to provide intuition. BaseModelClass contains some universally applicable methods, as well as some private methods (conventionally prefixed with _ in Python) that become useful after the model is trained.

class MyModel(UnsupervisedTrainingMixin, BaseModelClass):
    def __init__(self, adata):
        # sets some basic attributes like is_trained_ and a unique instance id (self.id)
        # retrieves the AnnDataManager that setup_anndata() registered for this adata
        # (simplified lookup in the class-level manager store)
        self.adata = adata
        self.adata_manager = self._setup_adata_manager_store[adata.uns["_scvi_uuid"]]
        self.summary_stats = self.adata_manager.summary_stats

    def _validate_anndata(self, adata):
        # check that adata is structurally equivalent to the
        # AnnData registered at setup time
        pass

    def _make_data_loader(self, adata):
        # return a dataloader to iterate over adata
        pass

    def train(self):
        # Universal train method provided by UnsupervisedTrainingMixin;
        # BaseModelClass does not come with train.
        # In general, train methods are straightforward to compose manually.
        pass

    def save(self):
        # universal save method
        # saves the module, the data registry, and attributes ending with _
        pass

    @classmethod
    def load(cls):
        # universal load method
        pass

    @classmethod
    def setup_anndata(cls, adata):
        # anndata registration method
        pass

Baseline version of SCVI class

Let’s now create the simplest possible version of the SCVI class. We inherit from BaseModelClass (together with UnsupervisedTrainingMixin, which provides the train method) and write our __init__ method.

We take care to do the following:

  1. Set the module attribute to be equal to our VAE module, which here is the torch-level version of scVI.

  2. Add a _model_summary_string attribute, which will be used as a representation for the model.

  3. Run self.init_params_ = self._get_init_params(locals()), which stores the arguments used to initialize the model, facilitating saving/loading of the model later.
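To give some intuition for step 3, the stdlib-only toy below records constructor arguments from locals() by keeping only the names that appear in the __init__ signature and setting aside the data object. ToyInitModel and _get_init_params_sketch are hypothetical names, and splitting the result into "kwargs" and "non_kwargs" is an assumption made for illustration; the real BaseModelClass._get_init_params may differ in detail.

```python
import inspect


class ToyInitModel:
    """Hypothetical stand-in for a model class; not an scvi-tools class."""

    def __init__(self, adata, n_latent: int = 10, **model_kwargs):
        self.adata = adata
        # Capture the arguments used to construct this instance.
        self.init_params_ = self._get_init_params_sketch(locals())

    def _get_init_params_sketch(self, local_vars: dict) -> dict:
        # Signature of the bound __init__ (excludes `self`).
        params = inspect.signature(self.__init__).parameters
        # Keep only local variables that are constructor parameters.
        captured = {name: local_vars[name] for name in params if name in local_vars}
        # The data object is registered/saved separately, so drop it.
        captured.pop("adata", None)
        # Separate the **kwargs catch-all from the named parameters.
        var_kw = [n for n, p in params.items() if p.kind == p.VAR_KEYWORD]
        kwargs = {n: captured.pop(n) for n in var_kw if n in captured}
        return {"kwargs": kwargs, "non_kwargs": captured}


toy = ToyInitModel(adata=object(), n_latent=5, dropout_rate=0.2)
print(toy.init_params_)
# {'kwargs': {'model_kwargs': {'dropout_rate': 0.2}}, 'non_kwargs': {'n_latent': 5}}
```

Storing the arguments this way means a saved model can later be rebuilt by calling the class again with the same settings.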

To initialize the VAE, we can use self.summary_stats, which holds information computed when the AnnData object was registered with setup_anndata() (here, the number of genes, n_vars, and the number of batches, n_batch). In this example, we have only exposed n_latent to users through SCVI. In practice, we try to expose only the most relevant parameters, since all other VAE parameters can still be passed through model_kwargs.

Finally, we write the setup_anndata class method, which registers the appropriate matrices inside the AnnData object that we will use to load data into the model. This method uses the AnnDataManager class to orchestrate the data registration process. More details about the AnnDataManager can be found in our data handling tutorial.

class SCVI(UnsupervisedTrainingMixin, BaseModelClass):
    """single-cell Variational Inference [Lopez18]_."""

    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata)

        self.module = VAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCVI Model with the following params: \nn_latent: {n_latent}"
        )
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        layer: Optional[str] = None,
        **kwargs,
    ) -> Optional[AnnData]:
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            # Dummy fields required for VAE class.
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, None),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, None),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

Now let’s explore what we can and cannot do with this model. Let’s get some data and initialize an SCVI instance. For testing purposes we like to use scvi.data.synthetic_iid(), which returns a small, simple AnnData object. As its representation below shows, it has not been registered yet, so we run it through setup_anndata() before creating the model.

adata = scvi.data.synthetic_iid()
adata
AnnData object with n_obs × n_vars = 400 × 100
    obs: 'batch', 'labels'
    uns: 'protein_names'
    obsm: 'protein_expression', 'accessibility'

Above, our setup_anndata() implementation ends with cls.register_manager(adata_manager). This call stores the newly created AnnDataManager instance in a class-specific dictionary called _setup_adata_manager_store. Specifically, this dictionary maps UUIDs (one per AnnData object, stored in adata.uns["_scvi_uuid"]) to the AnnDataManager instances created by that class’s setup_anndata() function.

On initialization, the model instance retrieves the AnnDataManager specific to the passed-in adata.

SCVI.setup_anndata(adata, batch_key="batch")
print(f"adata UUID (assigned by setup_anndata): {adata.uns['_scvi_uuid']}")
print(f"AnnDataManager: {SCVI._setup_adata_manager_store[adata.uns['_scvi_uuid']]}")
model = SCVI(adata)
model
adata UUID (assigned by setup_anndata): e0a91eb2-7b47-490a-9e6c-9598363a7242
AnnDataManager: <scvi.data._manager.AnnDataManager object at 0x7f14aaa67e90>
SCVI Model with the following params: 
n_latent: 10
Training status: Not Trained
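The store lookup described above can be pictured with a small stdlib-only toy, in which a plain dict stands in for the AnnData object and a "_scvi_uuid" key stands in for adata.uns["_scvi_uuid"]. ToyRegistryModel and ToyManager are hypothetical; the real store is filled through register_manager() and holds real AnnDataManager objects.

```python
import uuid


class ToyManager:
    """Hypothetical stand-in for AnnDataManager: remembers how one dataset was set up."""

    def __init__(self, batch_key=None):
        self.id = str(uuid.uuid4())
        self.batch_key = batch_key


class ToyRegistryModel:
    # Class-level store: dataset UUID -> manager created by setup_anndata().
    _setup_adata_manager_store: dict = {}

    @classmethod
    def setup_anndata(cls, adata: dict, batch_key=None) -> None:
        # Tag the dataset with a UUID the first time it is seen.
        adata_uuid = adata.setdefault("_scvi_uuid", str(uuid.uuid4()))
        cls._setup_adata_manager_store[adata_uuid] = ToyManager(batch_key)

    def __init__(self, adata: dict):
        # Look up the manager registered for this specific dataset.
        self.adata_manager = self._setup_adata_manager_store[adata["_scvi_uuid"]]


toy_adata = {}  # a plain dict stands in for an AnnData object
ToyRegistryModel.setup_anndata(toy_adata, batch_key="batch")
toy_model = ToyRegistryModel(toy_adata)
print(toy_model.adata_manager.batch_key)  # batch
```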

More AnnDataManager Details

The AnnDataManager class stores state about data registered with scvi-tools. Because each manager is specific to a single AnnData object, a model instance holds one AnnDataManager for every AnnData object it has interacted with. Besides setup_anndata(), new AnnDataManager objects are created by _validate_anndata() when it is called on an AnnData object other than the one the model instance was initialized with. Any method that reads data from an AnnData object through the scvi-tools data handling machinery (e.g. get_latent_representation()) should call _validate_anndata() first. These instance-specific AnnDataManager objects are kept in a separate class-specific manager store called _per_instance_manager_store, a nested dictionary that maps each model instance UUID (assigned on initialization) to a dictionary from AnnData UUIDs to AnnDataManager instances. This prevents a model instance from retrieving the wrong AnnDataManager when two model instances work on the same AnnData object.

print(f"model instance UUID: {model.id}")
print(f"adata UUID: {adata.uns['_scvi_uuid']}")
print(
    "AnnDataManager for adata: "
    f"{SCVI._per_instance_manager_store[model.id][adata.uns['_scvi_uuid']]}"
)  # { model instance UUID: { adata UUID: AnnDataManager } }
model instance UUID: da5c8fc0-0242-4eec-9948-e989250f75e7
adata UUID: e0a91eb2-7b47-490a-9e6c-9598363a7242
AnnDataManager for adata: <scvi.data._manager.AnnDataManager object at 0x7f14aaa67e90>
adata2 = scvi.data.synthetic_iid()
model._validate_anndata(adata2)
INFO     Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
AnnData object with n_obs × n_vars = 400 × 100
    obs: 'batch', 'labels', '_scvi_batch', '_scvi_labels'
    uns: 'protein_names', '_scvi_uuid', '_scvi_manager_uuid'
    obsm: 'protein_expression', 'accessibility'
print(f"adata2 UUID: {adata2.uns['_scvi_uuid']}")
print(
    f"Model instance specific manager store: {SCVI._per_instance_manager_store[model.id]}"
)  # { model instance UUID: { adata UUID: AnnDataManager } }
adata2 UUID: 737626ad-74e9-4113-8310-653734a17f31
Model instance specific manager store: {'e0a91eb2-7b47-490a-9e6c-9598363a7242': <scvi.data._manager.AnnDataManager object at 0x7f14aaa67e90>, '737626ad-74e9-4113-8310-653734a17f31': <scvi.data._manager.AnnDataManager object at 0x7f15fa433ad0>}
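The nested layout of the per-instance store can be mimicked in a few lines of plain Python. In this hypothetical toy, each model instance gets its own inner dictionary keyed by dataset UUID, so two instances working on the same dataset never share a manager. It illustrates the lookup pattern only; it does not reproduce how _validate_anndata() transfers the original setup to new data.

```python
import uuid
from collections import defaultdict


class ToyPerInstanceModel:
    """Hypothetical model class illustrating a per-instance manager store."""

    # Class-level store: model instance UUID -> {dataset UUID -> manager}.
    _per_instance_manager_store: defaultdict = defaultdict(dict)

    def __init__(self, batch_key=None):
        self.id = str(uuid.uuid4())  # instance UUID, analogous to model.id
        self.batch_key = batch_key

    def validate(self, adata: dict) -> dict:
        """Return this instance's manager for `adata`, creating one on first use."""
        adata_uuid = adata.setdefault("_scvi_uuid", str(uuid.uuid4()))
        store = self._per_instance_manager_store[self.id]
        if adata_uuid not in store:
            # Toy "manager": this instance's own setup applied to the dataset.
            store[adata_uuid] = {"id": str(uuid.uuid4()), "batch_key": self.batch_key}
        return store[adata_uuid]


shared = {}  # one dataset (a dict standing in for AnnData) used by two models
m1 = ToyPerInstanceModel(batch_key="batch")
m2 = ToyPerInstanceModel(batch_key=None)
print(m1.validate(shared)["batch_key"], m2.validate(shared)["batch_key"])  # batch None
```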

In addition, data registration can modify or add data on the AnnData object directly. As a result, if calls on two models are interleaved, one model may read fields that were written by another model instance’s data registration. To avoid this, _validate_anndata() also checks the AnnDataManager-specific UUID stored in adata.uns['_scvi_manager_uuid']. If this UUID does not match the AnnDataManager fetched from the manager store, the data registration is replayed on the AnnData object before any of its data is read. _validate_anndata() does this automatically.

As a result, we can interleave method calls on two model instances without worrying about this clobbering issue.

SCVI.setup_anndata(adata, batch_key=None)  # No batch correction.
model2 = SCVI(adata)
print(f"Manager UUID: {model2.adata_manager.id}")
print(f"Last setup with manager UUID: {adata.uns['_scvi_manager_uuid']}")
print(f"Encoded batch obs field: {adata.obs['_scvi_batch']}")
Manager UUID: 8808df2c-9024-4dbd-b165-fdc1cd03f7d6
Last setup with manager UUID: 8808df2c-9024-4dbd-b165-fdc1cd03f7d6
Encoded batch obs field: 0      0
1      0
2      0
3      0
4      0
      ..
395    0
396    0
397    0
398    0
399    0
Name: _scvi_batch, Length: 400, dtype: int8
model._validate_anndata(adata)  # Replays registration on adata
print(f"Manager UUID: {model.adata_manager.id}")
print(f"Last setup with manager UUID: {adata.uns['_scvi_manager_uuid']}")
print(f"Encoded batch obs field: {adata.obs['_scvi_batch']}")
Manager UUID: 656a8a4d-0c20-4922-9a4a-32ecd64631b1
Last setup with manager UUID: 656a8a4d-0c20-4922-9a4a-32ecd64631b1
Encoded batch obs field: 0      0
1      0
2      0
3      0
4      0
      ..
395    1
396    1
397    1
398    1
399    1
Name: _scvi_batch, Length: 400, dtype: int8
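The replay behavior above reduces to a UUID comparison. In the hypothetical toy below, a manager writes an encoded batch column plus its own UUID onto a dict that stands in for the AnnData object, and validate_sketch re-runs registration whenever the stamped UUID belongs to a different manager. The encoding rule (sorted category order) is a simplification chosen for the example.

```python
import uuid


class ToyReplayManager:
    """Hypothetical manager that encodes a batch column and stamps its UUID on the data."""

    def __init__(self, batch_key=None):
        self.id = str(uuid.uuid4())
        self.batch_key = batch_key

    def register(self, data: dict) -> None:
        if self.batch_key is None:
            # No batch correction: every cell gets code 0.
            data["_scvi_batch"] = [0] * data["n_obs"]
        else:
            values = data[self.batch_key]
            categories = sorted(set(values))
            data["_scvi_batch"] = [categories.index(v) for v in values]
        # Record which manager last wrote the derived fields.
        data["_scvi_manager_uuid"] = self.id


def validate_sketch(manager: ToyReplayManager, data: dict) -> dict:
    # If another manager wrote the derived fields last, replay this one's registration.
    if data.get("_scvi_manager_uuid") != manager.id:
        manager.register(data)
    return data


data = {"n_obs": 4, "batch": ["batch_0", "batch_0", "batch_1", "batch_1"]}
with_batch = ToyReplayManager(batch_key="batch")
no_batch = ToyReplayManager(batch_key=None)

with_batch.register(data)
no_batch.register(data)            # overwrites the encoded column
print(data["_scvi_batch"])         # [0, 0, 0, 0]
validate_sketch(with_batch, data)  # detects the UUID mismatch and replays
print(data["_scvi_batch"])         # [0, 0, 1, 1]
```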

The train method

A model can be trained simply by calling the train method.

model.train(max_epochs=20)
Epoch 20/20: 100%|██████████| 20/20 [00:00<00:00, 23.16it/s, v_num=1, train_loss_step=304, train_loss_epoch=304]

We were able to train this model because the class inherits the train method from UnsupervisedTrainingMixin. Let us now look at pseudocode for that train method. The function of each of these objects is described in the API reference.

from typing import Optional

from scvi.dataloaders import DataSplitter
from scvi.train import TrainingPlan, TrainRunner


def train(
    self,
    max_epochs: Optional[int] = 100,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    plan_kwargs: Optional[dict] = None,
    **trainer_kwargs,
):
    """Train the model."""
    # object to make train/validation/test dataloaders from the registered data
    data_splitter = DataSplitter(
        self.adata_manager,
        train_size=train_size,
        validation_size=validation_size,
        batch_size=batch_size,
    )
    # defines optimizers, training step, validation step, and logged metrics
    training_plan = TrainingPlan(self.module, **(plan_kwargs or {}))
    # creates the Lightning Trainer and runs pre- and post-training procedures (Trainer.fit())
    runner = TrainRunner(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        **trainer_kwargs,
    )
    return runner()
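The DataSplitter in this pseudocode partitions cell indices into training, validation, and (optionally) test sets. Here is a stdlib-only approximation of that idea; the helper name is hypothetical, and the rounding rules (ceiling for the training set, the remainder going to validation when validation_size is None) are an assumption for illustration that may not match the real DataSplitter exactly.

```python
import math
import random
from typing import Optional


def split_indices_sketch(
    n_obs: int,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    seed: int = 0,
) -> tuple[list[int], list[int], list[int]]:
    """Hypothetical helper: shuffle cell indices and cut them into train/val/test."""
    if not 0 < train_size <= 1:
        raise ValueError("train_size must be in (0, 1]")
    n_train = math.ceil(train_size * n_obs)
    if validation_size is None:
        # Everything not used for training goes to validation; no test set.
        n_val = n_obs - n_train
    else:
        if train_size + validation_size > 1:
            raise ValueError("train_size + validation_size must be <= 1")
        n_val = math.floor(validation_size * n_obs)
    indices = list(range(n_obs))
    random.Random(seed).shuffle(indices)
    train_idx = indices[:n_train]
    val_idx = indices[n_train : n_train + n_val]
    test_idx = indices[n_train + n_val :]
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = split_indices_sketch(400)  # 400 cells, as in synthetic_iid()
print(len(train_idx), len(val_idx), len(test_idx))  # 360 40 0
```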

We notice two new things:

  1. A training plan (training_plan)

  2. A train runner (runner)

The TrainRunner is a lightweight wrapper around PyTorch Lightning’s Trainer, which can be treated as a black box once a TrainingPlan is defined. So what does the TrainingPlan do?

  1. Configures optimizers (e.g., Adam) and learning rate schedulers.

  2. Defines the training step, which runs a minibatch of data through the model and records the loss.

  3. Defines the validation step, same as training step, but for validation data.

  4. Records relevant metrics, such as the ELBO.
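To make this division of labor concrete, here is a deliberately tiny, framework-free analogue. ToyTrainingPlan owns the optimizer setting, the training step, the validation step, and the logged metrics, while toy_runner plays the role of the Trainer by deciding when each step runs. Both names are hypothetical, and the model is a one-parameter line fit rather than a VAE; the real TrainingPlan is a Lightning module with a different interface.

```python
class ToyTrainingPlan:
    """Hypothetical plan: fits y = w * x by gradient descent on mean squared error."""

    def __init__(self, lr: float = 0.1):
        self.w = 0.0  # the "module" parameter
        self.lr = lr  # optimizer configuration
        self.history = {"train_loss": [], "validation_loss": []}

    def _loss_and_grad(self, batch):
        n = len(batch)
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / n
        return loss, grad

    def training_step(self, batch):
        loss, grad = self._loss_and_grad(batch)
        self.w -= self.lr * grad  # optimizer step
        self.history["train_loss"].append(loss)

    def validation_step(self, batch):
        loss, _ = self._loss_and_grad(batch)  # no parameter update
        self.history["validation_loss"].append(loss)


def toy_runner(plan, train_batches, val_batches, max_epochs: int):
    """Hypothetical stand-in for the Trainer: it only schedules the plan's steps."""
    for _ in range(max_epochs):
        for batch in train_batches:
            plan.training_step(batch)
        for batch in val_batches:
            plan.validation_step(batch)
    return plan


train_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]  # points on y = 2x
val_batches = [[(4.0, 8.0)]]
plan = toy_runner(ToyTrainingPlan(lr=0.05), train_batches, val_batches, max_epochs=50)
print(round(plan.w, 3))  # 2.0
```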

In scvi-tools we have scvi.train.TrainingPlan, which should cover many use cases, from VAEs and VI to MLE and MAP estimation. Developers may find that they need a custom TrainingPlan for, e.g., multiple optimizers and complex training schemes. These can be written and used by the model class.

Developers may also override this train method to add custom functionality such as early stopping (see TOTALVI’s train method). In most cases the overriding train method can call super().train(), which resolves to the UnsupervisedTrainingMixin train method.
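The override-then-defer pattern relies on Python’s method resolution order: because the base class here defines no train(), super().train() resolves to the mixin. The classes and the changed defaults below are hypothetical and only illustrate the pattern; they are not TOTALVI’s actual arguments.

```python
class ToyTrainingMixin:
    """Hypothetical stand-in for a mixin that provides a generic train()."""

    def train(self, max_epochs: int = 100, **trainer_kwargs) -> dict:
        # A real mixin would build a data splitter, training plan, and runner here;
        # this toy just returns the settings it received.
        return {"max_epochs": max_epochs, **trainer_kwargs}


class ToyBase:
    """Hypothetical base model class with no train() of its own."""


class ToyModelWithEarlyStopping(ToyTrainingMixin, ToyBase):
    def train(self, max_epochs: int = 400, early_stopping: bool = True, **trainer_kwargs) -> dict:
        # Change the defaults, then defer to the mixin's generic implementation.
        return super().train(max_epochs=max_epochs, early_stopping=early_stopping, **trainer_kwargs)


print(ToyModelWithEarlyStopping().train())  # {'max_epochs': 400, 'early_stopping': True}
```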

Save and load

We can also save and load this model object, as it follows the expected structure.

model_dir = os.path.join(save_dir.name, "saved_model")

model.save(model_dir, save_anndata=True)
model = SCVI.load(model_dir)
INFO     File /tmp/tmpjs643jnf/saved_model/model.pt already downloaded                                             
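The BaseModelClass sketch notes that save() persists attributes whose names end with an underscore (such as is_trained_ and init_params_). The toy below shows only that selection rule, using vars() and JSON; the class name, the attrs.json file, and its layout are hypothetical and do not reflect the real model.pt format.

```python
import json
import os
import tempfile


class ToySavable:
    """Hypothetical model-like object used to illustrate the trailing-underscore rule."""

    def __init__(self):
        self.is_trained_ = True               # saved: name ends with "_"
        self.init_params_ = {"n_latent": 10}  # saved: name ends with "_"
        self._model_summary_string = "toy"    # not saved: no trailing "_"

    def user_attributes(self) -> dict:
        # Select persistent state by the trailing-underscore naming convention.
        return {k: v for k, v in vars(self).items() if k.endswith("_")}

    def save(self, dir_path: str) -> str:
        os.makedirs(dir_path, exist_ok=True)
        path = os.path.join(dir_path, "attrs.json")
        with open(path, "w") as f:
            json.dump(self.user_attributes(), f)
        return path


with tempfile.TemporaryDirectory() as d:
    path = ToySavable().save(d)
    with open(path) as f:
        print(json.load(f))  # {'is_trained_': True, 'init_params_': {'n_latent': 10}}
```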

Writing methods to query the model

So we have a model that wraps a module that has been trained. How can we get information out of the module and present it cleanly to our users? Let’s implement a simple example: getting the latent representation out of the VAE.

This method has the following structure:

  1. Validate the user-supplied data

  2. Create a data loader

  3. Iterate over the data loader and feed into the VAE, getting the tensor of interest out of the VAE.

@torch.inference_mode()
def get_latent_representation(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    batch_size: Optional[int] = None,
) -> np.ndarray:
    r"""Return the latent representation for each cell.

    Parameters
    ----------
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns
    -------
    latent_representation : np.ndarray
        Low-dimensional representation for each cell
    """
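    if not self.is_trained_:
        raise RuntimeError("Please train the model first.")

    # 1. validate the user-supplied data (falls back to the initial AnnData if None)
    adata = self._validate_anndata(adata)
    # 2. create a data loader over the requested cells
    dataloader = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
    latent = []
    # 3. run each minibatch through the encoder and keep the posterior mean
    for tensors in dataloader:
        inference_inputs = self.module._get_inference_input(tensors)
        outputs = self.module.inference(**inference_inputs)
        # `qz` is the variational posterior q(z | x), a torch Normal distribution
        qz = outputs["qz"]

        latent += [qz.loc.cpu()]
    return torch.cat(latent).numpy()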

Note

Validating the anndata is critical to the user experience. If None is passed it just returns the anndata used to initialize the model, but if a different object is passed, it checks that this new object is equivalent in structure to the anndata passed to the model. We took great care in engineering this function so as to allow passing anndata objects with potentially missing categories (e.g., model was trained on batches ["A", "B", "C"], but the passed anndata only has ["B", "C"]). These sorts of checks will ensure that your module will see data that it expects, and the user will get the results they expect without advanced data manipulations.
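The missing-category case in this note comes down to encoding new data against the categories recorded at setup time rather than the categories present in the new data. A minimal stdlib sketch follows; the function name is hypothetical, and raising on unseen categories is a choice made for this toy rather than a description of scvi-tools behavior.

```python
def encode_with_reference(values, reference_categories):
    """Map each value to its index in the categories recorded at setup time."""
    mapping = {cat: i for i, cat in enumerate(reference_categories)}
    unknown = sorted(set(values) - set(mapping))
    if unknown:
        # In this toy, categories the model never saw cannot be encoded.
        raise ValueError(f"Unseen categories: {unknown}")
    return [mapping[v] for v in values]


trained_on = ["A", "B", "C"]
new_batch = ["B", "C", "C", "B"]
print(encode_with_reference(new_batch, trained_on))  # [1, 2, 2, 1]

# Naively re-encoding from the new data alone would shift the codes:
naive = sorted(set(new_batch))
print([naive.index(v) for v in new_batch])  # [0, 1, 1, 0]
```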

As a convention, we like to keep the module code as bare as possible and leave all posterior manipulation of module tensors to the model class methods. However, it would also have been possible to write a get_z method in the module and have the model class call that method.

Mixing in pre-coded features

We have a number of Mixin classes that can add functionality to your model through inheritance. Here we demonstrate the VAEMixin class.

Let’s try to get the latent representation from the object we already created.

try:
    model.get_latent_representation()
except AttributeError:
    print("This function does not exist")
This function does not exist

This method becomes available once VAEMixin is inherited. Here is an overview of the mixin methods, which are written generally enough to be broadly useful to those building VAEs.

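# Overview only: running this cell would shadow scvi.model.base.VAEMixin.
from typing import Dict, Union


class VAEMixin: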
    @torch.inference_mode()
    def get_elbo(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> float:
        pass

    @torch.inference_mode()
    def get_marginal_ll(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        n_mc_samples: int = 1000,
        batch_size: Optional[int] = None,
    ) -> float:
        pass

    @torch.inference_mode()
    def get_reconstruction_error(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> Union[float, Dict[str, float]]:
        pass

    @torch.inference_mode()
    def get_latent_representation(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        give_mean: bool = True,
        mc_samples: int = 5000,
        batch_size: Optional[int] = None,
    ) -> np.ndarray:
        pass
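Because Python resolves methods along a class’s method resolution order (MRO), listing a mixin among the bases is enough to make its methods available on the model. The toy classes below are hypothetical stand-ins with placeholder behavior; only the inheritance pattern mirrors SCVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass).

```python
class ToyBaseModel:
    def __init__(self):
        self.is_trained_ = False


class ToyTrainMixin:
    def train(self):
        self.is_trained_ = True


class ToyVAEMixin:
    def get_latent_representation(self):
        if not self.is_trained_:
            raise RuntimeError("Please train the model first.")
        return [[0.0, 0.0]]  # placeholder latent coordinates


class ToyPlain(ToyTrainMixin, ToyBaseModel):
    pass


class ToyWithVAE(ToyVAEMixin, ToyTrainMixin, ToyBaseModel):
    pass


print(hasattr(ToyPlain(), "get_latent_representation"))  # False
toy_vae_model = ToyWithVAE()
toy_vae_model.train()
print(toy_vae_model.get_latent_representation())  # [[0.0, 0.0]]
print([cls.__name__ for cls in ToyWithVAE.__mro__])
# ['ToyWithVAE', 'ToyVAEMixin', 'ToyTrainMixin', 'ToyBaseModel', 'object']
```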

Let’s now inherit the mixin into our SCVI class.

class SCVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """single-cell Variational Inference [Lopez18]_."""

    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata)

        self.module = VAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCVI Model with the following params: \nn_latent: {n_latent}"
        )
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        layer: Optional[str] = None,
        **kwargs,
    ) -> Optional[AnnData]:
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            # Dummy fields required for VAE class.
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, None),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, None),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)
SCVI.setup_anndata(adata, batch_key="batch")
model = SCVI(adata)
model.train(10)
model.get_latent_representation()
Epoch 10/10: 100%|██████████| 10/10 [00:00<00:00, 62.74it/s, v_num=1, train_loss_step=309, train_loss_epoch=309]
array([[-0.01243285,  0.2494134 ,  1.2999499 , ...,  0.16254096,
         0.14470463,  0.5902857 ],
       [-0.04466929,  0.06925905,  0.3897732 , ...,  0.7805659 ,
         0.17922966,  0.4863392 ],
       [ 0.15143889,  0.62429947,  0.18776797, ...,  0.30747014,
         0.42057133,  0.7097493 ],
       ...,
       [ 0.08512358,  0.4220351 ,  0.7521035 , ...,  0.5072541 ,
         0.27931175,  0.7883702 ],
       [ 0.34226996,  0.06517565,  0.37391824, ...,  0.9882941 ,
         0.16113117,  0.38623482],
       [ 0.1523135 ,  0.16174033,  1.1513405 , ...,  0.6617261 ,
         0.6510729 ,  0.5568458 ]], dtype=float32)

Summary

We learned the structure of the high-level model classes in scvi-tools, and learned how a simple version of SCVI is implemented.

Questions? Comments? Keep the discussion going on our forum