Note
This page was generated from model_user_guide.ipynb.
Constructing a high-level model
[1]:
%%capture
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip install --quiet scvi-tools
[2]:
import numpy as np
import scvi
import torch
Global seed set to 0
At this point we have covered:

- Data registration via scvi.data.setup_anndata and dataloaders via AnnDataLoader
- Building a probabilistic model by subclassing BaseModuleClass
In this tutorial, we will cover the highest-level classes in scvi-tools: the model classes. The main purpose of these classes (e.g., scvi.model.SCVI) is to wrap module instantiation, training, and subsequent posterior queries of our module into a convenient interface. These model classes are the fundamental objects driving scientific analysis of data with scvi-tools. By convention, we will refer to these objects as “models” and the lower-level objects presented in the previous tutorial as “modules”.
A simple model class
Here we will walk through an example of building the scvi.model.SCVI class. We will progressively add functionality to the class.
Sketch of BaseModelClass
Let us start by providing a high-level overview of the BaseModelClass that we will inherit. Note that this is pseudocode intended to provide intuition. We see that BaseModelClass contains some universally applicable methods, and some private methods (conventionally prefixed with _ in Python) that will become useful after training the model.
class MyModel(UnsupervisedTrainingMixin, BaseModelClass):
    def __init__(self, adata):
        # sets some basic attributes like is_trained_
        # records the setup_dict registered in the adata
        self.adata = adata
        self.scvi_setup_dict_ = adata.uns["_scvi"]
        self.summary_stats = self.scvi_setup_dict_["summary_stats"]

    def _validate_anndata(self, adata):
        # check that the anndata is equivalent by comparing
        # it to the initial setup_dict

    def _make_dataloader(self, adata):
        # return a dataloader to iterate over adata

    def train(...):
        # universal train method provided by UnsupervisedTrainingMixin
        # BaseModelClass does not come with train
        # in general, train methods are straightforward to compose manually

    def save(...):
        # universal save method
        # saves module, anndata setup dict, and attributes ending with _

    def load(...):
        # universal load method

    @staticmethod
    def setup_anndata(adata, ...):
        # anndata registration method
Baseline version of the SCVI class
Let’s now create the simplest possible version of the SCVI class. We inherit the BaseModelClass and write our __init__ method.
We take care to do the following:

1. Set the module attribute to be equal to our VAE module, which here is the torch-level version of scVI.
2. Add a _model_summary_string attribute, which will be used as a representation for the model.
3. Run self.init_params_ = self._get_init_params(locals()), which stores the arguments used to initialize the model, facilitating saving/loading of the model later.
To initialize the VAE, we can use the information in self.summary_stats, which was stored in the anndata object at setup_anndata() time. In this example, we have only exposed n_latent to users through SCVI. In practice, we try to expose only the most relevant parameters, as all other parameters can be set by passing model_kwargs.
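For example, once the class below is defined, any extra keyword argument is forwarded to the underlying VAE. This is a sketch; n_hidden is a VAE constructor argument, not a parameter exposed by our class:

# n_latent is exposed directly; n_hidden flows through **model_kwargs to VAE
model = SCVI(adata, n_latent=20, n_hidden=256)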
Finally, we write the setup_anndata static method, which registers the appropriate matrices inside the anndata that we will use to load data into the model. Here we use a private global function, but in future releases we will have a more object-oriented approach for setup.
[3]:
from typing import Optional

from anndata import AnnData

from scvi.data._anndata import _setup_anndata
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.module import VAE


class SCVI(UnsupervisedTrainingMixin, BaseModelClass):
    """single-cell Variational Inference [Lopez18]_."""

    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata)
        # the module is the torch-level version of scVI
        self.module = VAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            "SCVI Model with the following params: \nn_latent: {}"
        ).format(n_latent)
        # store the arguments used to initialize the model for save/load
        self.init_params_ = self._get_init_params(locals())

    @staticmethod
    def setup_anndata(
        adata: AnnData,
        batch_key: Optional[str] = None,
        layer: Optional[str] = None,
        copy: bool = False,
    ) -> Optional[AnnData]:
        return _setup_anndata(
            adata,
            batch_key=batch_key,
            layer=layer,
            copy=copy,
        )
Now we explore what we can and cannot do with this model. Let’s get some data and initialize an SCVI instance. Of note, for testing purposes we like to use scvi.data.synthetic_iid(), which returns a simple, small anndata object. By default it has already been run through setup_anndata(); here we pass run_setup_anndata=False so that we can run our own setup method.
[7]:
adata = scvi.data.synthetic_iid(run_setup_anndata=False)
adata
[7]:
AnnData object with n_obs × n_vars = 400 × 100
obs: 'batch', 'labels'
uns: 'protein_names'
obsm: 'protein_expression'
[8]:
SCVI.setup_anndata(adata, batch_key="batch")
model = SCVI(adata)
model
INFO Using batches from adata.obs["batch"]
INFO No label_key inputted, assuming all cells have same label
INFO Using data from adata.X
INFO Successfully registered anndata object containing 400 cells, 100 vars, 2 batches, 1
labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra
continuous covariates.
INFO Please do not further modify adata until model is trained.
[8]:
SCVI Model with the following params:
n_latent: 10
Training status: Not Trained

To print summary of associated AnnData, use: scvi.data.view_anndata_setup(model.adata)
[9]:
model.train(max_epochs=20)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Epoch 20/20: 100%|██████████| 20/20 [00:01<00:00, 15.45it/s, loss=304, v_num=1]
The train method
We were able to train this model because the train method is inherited by the class. Let us now take a look at pseudocode of the train method of UnsupervisedTrainingMixin. The function of each of these objects is described in the API reference.
def train(
    self,
    max_epochs: Optional[int] = 100,
    use_gpu: Optional[bool] = None,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    **kwargs,
):
    """Train the model."""
    # object to make train/test/val dataloaders
    data_splitter = DataSplitter(
        self.adata,
        train_size=train_size,
        validation_size=validation_size,
        batch_size=batch_size,
        use_gpu=use_gpu,
    )
    # defines optimizers, training step, val step, logged metrics
    training_plan = TrainingPlan(
        self.module, len(data_splitter.train_idx),
    )
    # creates Trainer, pre- and post-training procedures (Trainer.fit())
    runner = TrainRunner(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        **kwargs,
    )
    return runner()
We notice two new things:

- A training plan (training_plan)
- A train runner (runner)
The TrainRunner is a lightweight wrapper around PyTorch Lightning’s Trainer (https://pytorch-lightning.readthedocs.io/en/stable/trainer.html#trainer-class-api), which runs as a complete black box once a TrainingPlan is defined. So what does the TrainingPlan do?
- Configures optimizers (e.g., Adam) and learning rate schedulers.
- Defines the training step, which runs a minibatch of data through the model and records the loss.
- Defines the validation step, same as the training step, but for validation data.
- Records relevant metrics, such as the ELBO.
In scvi-tools we have scvi.lightning.TrainingPlan, which should cover many use cases, from VAEs and VI to MLE and MAP estimation. Developers may find that they need a custom TrainingPlan for, e.g., multiple optimizers or a complex training scheme. These can be written and used by the model class, as sketched below.
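For example, a custom plan might only override how optimizers are configured. The following is a minimal sketch, assuming TrainingPlan follows the standard pytorch-lightning LightningModule interface and stores the module as self.module; AdamWTrainingPlan is a hypothetical name:

from scvi.lightning import TrainingPlan

class AdamWTrainingPlan(TrainingPlan):
    # hypothetical plan that swaps Adam for AdamW
    def configure_optimizers(self):
        # standard pytorch-lightning hook; pass only trainable parameters
        params = filter(lambda p: p.requires_grad, self.module.parameters())
        return torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-6)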
Developers may also override this train method to add custom functionality like early stopping (see TOTALVI’s train method). In most cases the higher-level train method can call super().train(), which here resolves to the UnsupervisedTrainingMixin train method.
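As a sketch, such an override might look like the following (SCVIWithCustomTrain is a hypothetical subclass that only changes a default before deferring to the inherited method):

class SCVIWithCustomTrain(SCVI):
    def train(self, max_epochs: int = 400, **kwargs):
        # do any model-specific setup here (e.g., configure early stopping),
        # then defer to the inherited train method
        return super().train(max_epochs=max_epochs, **kwargs)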
Save and load
We can also save and load this model object, as it follows the expected structure.
[10]:
model.save("saved_model/", save_anndata=True)
model = SCVI.load("saved_model/")
INFO Using data from adata.X
INFO Registered keys:['X', 'batch_indices', 'labels']
INFO Successfully registered anndata object containing 400 cells, 100 vars, 2 batches, 1
labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra
continuous covariates.
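To see what the save method wrote to disk, we can list the directory (a quick check; exact filenames may vary across scvi-tools versions):

import os

print(sorted(os.listdir("saved_model/")))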
Writing methods to query the model
So we have a model that wraps a module that has been trained. How can we get information out of the module and present it cleanly to our users? Let’s implement a simple example: getting the latent representation out of the VAE.
This method has the following structure:

1. Validate the user-supplied data
2. Create a data loader
3. Iterate over the data loader, feed each minibatch into the VAE, and collect the tensor of interest
[11]:
from typing import Optional, Sequence


@torch.no_grad()
def get_latent_representation(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    batch_size: Optional[int] = None,
) -> np.ndarray:
    r"""
    Return the latent representation for each cell.

    Parameters
    ----------
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns
    -------
    latent_representation : np.ndarray
        Low-dimensional representation for each cell
    """
    if self.is_trained_ is False:
        raise RuntimeError("Please train the model first.")

    adata = self._validate_anndata(adata)
    dataloader = self._make_dataloader(adata=adata, indices=indices, batch_size=batch_size)
    latent = []
    for tensors in dataloader:
        # run only the inference (encoder) part of the module
        inference_inputs = self.module._get_inference_input(tensors)
        outputs = self.module.inference(**inference_inputs)
        qz_m = outputs["qz_m"]
        latent += [qz_m.cpu()]
    return torch.cat(latent).numpy()
Note
Validating the anndata is critical to the user experience. If None is passed, it just returns the anndata used to initialize the model; but if a different object is passed, it checks that this new object is equivalent in structure to the anndata passed to the model. We took great care in engineering this function to allow passing anndata objects with potentially missing categories (e.g., the model was trained on batches ["A", "B", "C"], but the passed anndata only has ["B", "C"]).
These sorts of checks will ensure that your module sees the data it expects, and that the user gets the results they expect without advanced data manipulations.
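For instance, we can query a structurally equivalent subset containing only one of the batches seen at training time. This is a sketch assuming synthetic_iid labels its batches "batch_0" and "batch_1"; bdata is a hypothetical name, and we call the standalone function directly, passing the model as self (the VAEMixin below attaches an equivalent method to the class):

# subset to a single batch seen at training time; _validate_anndata
# checks that the structure still matches the registered setup
bdata = adata[adata.obs["batch"] == "batch_0"].copy()
latent_subset = get_latent_representation(model, bdata)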
As a convention, we like to keep the module code as bare as possible and leave all posterior manipulation of module tensors to the model class methods. However, it would have been possible to write a get_z method in the module, and just have the model class call that method.
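For illustration, such a module-level method might look like this hypothetical sketch (VAEWithGetZ is not part of scvi-tools):

class VAEWithGetZ(VAE):
    @torch.no_grad()
    def get_z(self, tensors):
        # run the encoder and return the mean of q(z | x)
        inference_inputs = self._get_inference_input(tensors)
        outputs = self.inference(**inference_inputs)
        return outputs["qz_m"]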
Mixing in pre-coded features
We have a number of mixin classes that can add functionality to your model through inheritance. Here we demonstrate the VAEMixin class (https://www.scvi-tools.org/en/stable/api/reference/scvi.model.base.VAEMixin.html#scvi.model.base.VAEMixin).
Let’s try to get the latent representation from the object we already created.
[12]:
try:
    model.get_latent_representation()
except AttributeError:
    print("This function does not exist")
This function does not exist
This method becomes available once the VAEMixin is inherited. Here’s an overview of the mixin methods, which are coded generally enough that they should be broadly useful to those building VAEs.
class VAEMixin:
    @torch.no_grad()
    def get_elbo(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> float:
        pass

    @torch.no_grad()
    def get_marginal_ll(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        n_mc_samples: int = 1000,
        batch_size: Optional[int] = None,
    ) -> float:
        pass

    @torch.no_grad()
    def get_reconstruction_error(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> Union[float, Dict[str, float]]:
        pass

    @torch.no_grad()
    def get_latent_representation(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        give_mean: bool = True,
        mc_samples: int = 5000,
        batch_size: Optional[int] = None,
    ) -> np.ndarray:
        pass
Let’s now inherit the mixin into our SCVI class.
[13]:
from scvi.model.base import VAEMixin, UnsupervisedTrainingMixin


class SCVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """single-cell Variational Inference [Lopez18]_."""

    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata)
        self.module = VAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            "SCVI Model with the following params: \nn_latent: {}"
        ).format(n_latent)
        self.init_params_ = self._get_init_params(locals())

    @staticmethod
    def setup_anndata(
        adata: AnnData,
        batch_key: Optional[str] = None,
        layer: Optional[str] = None,
        copy: bool = False,
    ) -> Optional[AnnData]:
        return _setup_anndata(
            adata,
            batch_key=batch_key,
            layer=layer,
            copy=copy,
        )
[14]:
model = SCVI(adata)
model.train(10)
model.get_latent_representation()
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Epoch 10/10: 100%|██████████| 10/10 [00:00<00:00, 17.86it/s, loss=312, v_num=1]
[14]:
array([[-0.08544915, -0.25176027, -0.01428087, ..., 0.16004656,
0.04322858, 0.6210681 ],
[-0.207152 , -0.6063088 , -0.1941325 , ..., -0.2693901 ,
0.47606018, 0.16438939],
[-1.2161467 , -0.37276846, -0.7683172 , ..., 0.00299879,
-0.17553166, 0.20497267],
...,
[-0.25556734, -0.41497228, -0.56390095, ..., 0.10808641,
0.23799646, 0.1617701 ],
[-0.20167045, 0.10796846, 0.05922389, ..., 0.01984317,
0.00181232, -0.2722305 ],
[-0.89545447, -0.12291664, -1.0124316 , ..., 0.09532419,
-0.45145237, -0.02742232]], dtype=float32)
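The returned array can be dropped straight into a standard single-cell workflow. For example, with scanpy (a sketch; "X_scVI" is just a conventional key name, not something scvi-tools sets for you):

import scanpy as sc

# store the latent representation in the AnnData and use it downstream
adata.obsm["X_scVI"] = model.get_latent_representation()
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.umap(adata)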