Note

This page was generated from data_tutorial.ipynb.

# Data handling in scvi-tools

In this tutorial we will cover how data is handled in scvi-tools.

Sections:

1. Introduction to the registry, which comprises data_registry, state_registry, and summary_stats.

2. Explanation of AnnDataField classes and how they populate the registry via the AnnDataManager.

3. Data loading with AnnDataLoader() outside of scvi-tools models.

4. Writing a setup_anndata() function for an scvi-tools model.

[ ]:

!pip install --quiet scvi-colab
from scvi_colab import install

install()

[84]:

import numpy as np
import scvi
import torch.nn
from scvi.data import AnnDataManager
from scvi.data.fields import CategoricalObsField, LayerField
from scvi.dataloaders import AnnDataLoader  # used in the data loading section


## Recording AnnData state with object registration

scvi-tools knows which parts of an AnnData object to load into models during training and inference through a data registration process handled by setup_anndata().

This setup process is orchestrated by an AnnDataManager object which wraps the AnnData object and creates a corresponding registry.

In this section we enumerate the fields in the registry object. The registry takes the form of a nested dictionary and is stored as an instance variable of an AnnDataManager object, adata_manager.registry.

The top level of the registry contains the following keys:

• scvi_version keeps track of the version of scvi-tools used to set up the AnnData object.

• model_name and setup_args keep track of the model and arguments used to run setup_anndata(). These fields are optional, since the AnnDataManager can also be created outside of a setup_anndata() function.

• field_registries is a dictionary that maps registry keys (e.g. batch, labels) to additional field-specific information.

Within each field registry, there are the following three keys:

• data_registry contains the location of data to load. This is what is used by the DataLoaders to iterate over the AnnData.

• state_registry contains any state (e.g. categorical mappings for batch) relevant to the field during register_field().

• summary_stats contains summary statistics relevant to the field.
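As a rough picture of the nesting described above, the registry can be thought of as a plain dictionary. The sketch below is hand-written for illustration and is not produced by scvi-tools; the variable name `example_registry` and the placeholder values are assumptions.

```python
# Hand-written sketch of the registry layout; not output from scvi-tools.
example_registry = {
    "scvi_version": "0.15.0a0",
    "model_name": None,  # placeholder: optional when no setup_anndata() is involved
    "setup_args": None,  # placeholder: optional for the same reason
    "field_registries": {
        "batch": {
            "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"},
            "state_registry": {
                "categorical_mapping": ["batch_0", "batch_1"],
                "original_key": "batch",
            },
            "summary_stats": {"n_batch": 2},
        },
    },
}

# Every field registry exposes the same three keys.
for registry_key, field_registry in example_registry["field_registries"].items():
    print(registry_key, sorted(field_registry))
```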

Here we construct an AnnDataManager and create a registry via register_fields(). In the next section, we will break down how the registry is populated as a function of the AnnDataFields. We can visualize the contents of the registry via the function view_registry().

[86]:

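adata = scvi.data.synthetic_iid()

anndata_fields = [
    LayerField(registry_key="x", layer=None, is_count_data=True),
    CategoricalObsField(registry_key="batch", obs_key="batch"),
]
# Wrap the fields in a manager, then register them on the AnnData object
adata_manager = AnnDataManager(fields=anndata_fields)
adata_manager.register_fields(adata)
print(
    adata_manager.registry.keys()
)  # There is additionally a _scvi_uuid key which is used to uniquely identify AnnData objects for subsequent retrieval.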

dict_keys(['scvi_version', 'model_name', 'setup_args', 'field_registries', '_scvi_uuid'])

[87]:

adata_manager.view_registry()

Anndata setup with scvi-tools version 0.15.0a0.

     Summary Statistics
┏━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃ Summary Stat Key ┃ Value ┃
┡━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
│     n_cells      │  400  │
│      n_vars      │  100  │
│     n_batch      │   2   │
└──────────────────┴───────┘

               Data Registry
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Registry Key ┃   scvi-tools Location    ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      x       │         adata.X          │
│    batch     │ adata.obs['_scvi_batch'] │
└──────────────┴──────────────────────────┘

                  batch State Registry
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃  Source Location   ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['batch'] │  batch_0   │          0          │
│                    │  batch_1   │          1          │
└────────────────────┴────────────┴─────────────────────┘


The above summary incorporates all three of the components making up each field registry as mentioned before.

### Data Registry

First, let's turn our attention to the data_registry.

This is used by the AnnDataLoader during training to:

• Access the correct “slots” of AnnData

• Minibatch the data, while optionally densifying sparse-formatted data

Each key of the data_registry is the name of a tensor and is used to retrieve the data from the dataloader output.

• All the data registered via register_fields() have registry keys associated with them, as defined in the AnnDataField class.

The value of each key in the data_registry is a dictionary with two keys: attr_name and attr_key.

• attr_name is the attribute of adata to load data from, e.g. obs, obsm, layers.

• attr_key is the key of the attribute to access the data.

For example, based off the following data_registry, batch information is loaded from adata.obs['_scvi_batch'] and will be accessible via batch.

While the data registry dictionary is stored within the registry, the AnnDataManager provides a helper method, adata_manager.data_registry, which coalesces the full data registry across each of the fields. This helper method additionally wraps the dictionary in a custom attrdict class which allows dictionary access via dot notation (e.g. data_registry.batch.attr_name).
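To make the coalescing and dot-notation access concrete, here is a minimal sketch using only the standard library. The `AttrDict` class and the `field_registries` variable are hypothetical stand-ins written for this example; they are not the classes scvi-tools uses internally.

```python
class AttrDict(dict):
    """Hypothetical stand-in for a dot-accessible dictionary."""

    def __getattr__(self, name):
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name) from None
        # Wrap nested dictionaries so chained dot access keeps working.
        return AttrDict(value) if isinstance(value, dict) else value


# Per-field data registries, as they might sit inside the full registry.
field_registries = {
    "x": {"data_registry": {"attr_name": "X", "attr_key": None}},
    "batch": {"data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}},
}

# Coalesce the per-field entries into one mapping keyed by registry key.
data_registry = AttrDict(
    {key: field["data_registry"] for key, field in field_registries.items()}
)

print(data_registry.batch.attr_name)       # obs
print(data_registry["batch"]["attr_key"])  # _scvi_batch
```

Both access styles read the same underlying dictionary, so code that expects a plain `dict` continues to work.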

[88]:

data_registry = adata_manager.data_registry
data_registry

[88]:

attrdict({'x': attrdict({'attr_name': 'X', 'attr_key': None}), 'batch': attrdict({'attr_name': 'obs', 'attr_key': '_scvi_batch'})})

[89]:

print(data_registry["batch"])
print(data_registry.batch.attr_key)

attrdict({'attr_name': 'obs', 'attr_key': '_scvi_batch'})
_scvi_batch


### State Registries

During the data registration process, we also keep track of additional information from the registration process, necessary for model initialization or downstream functionality. For example, for the batch field, scvi-tools keeps track of the location of the original data as well as the categorical to integer mappings.

The batch state registry holds the following two keys:

• original_key is the original key passed in by the user to load the data.

• categorical_mapping is the categorical to integer mapping of the data. The index of the category is its corresponding integer representation.

We can access a state registry via the function AnnDataManager.get_state_registry() which takes a registry key.

[90]:

batch_state_registry = adata_manager.get_state_registry("batch")
print(batch_state_registry.keys())

print(f"Categorical mapping: {batch_state_registry.categorical_mapping}")
print(f"Original key: {batch_state_registry.original_key}")

dict_keys(['categorical_mapping', 'original_key'])
Categorical mapping: ['batch_0' 'batch_1']
Original key: batch
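Because the index of each category is its integer representation, the mapping can be used in both directions. The following is a small standard-library sketch of that idea; the `observed` values are made up for illustration.

```python
categorical_mapping = ["batch_0", "batch_1"]  # position == integer code

# Encoding: look up each category's position in the mapping.
code_of = {category: code for code, category in enumerate(categorical_mapping)}
observed = ["batch_1", "batch_0", "batch_0", "batch_1"]
encoded = [code_of[category] for category in observed]
print(encoded)  # [1, 0, 0, 1]

# Decoding: index back into the mapping.
decoded = [categorical_mapping[code] for code in encoded]
print(decoded == observed)  # True
```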


### Summary Stats

Lastly, the summary stats dictionary stores summary statistics that models use frequently, which avoids recomputing them and lets view_registry() report them. As with the other two components, the AnnDataManager exposes it through a helper property, adata_manager.summary_stats.

[91]:

adata_manager.summary_stats

[91]:

attrdict({'n_cells': 400, 'n_vars': 100, 'n_batch': 2})


## AnnDataManager and AnnDataFields

Now that we have gone over the registered state of an AnnDataManager, we can go over how the underlying logic is organized.

While the AnnDataManager provides the main interface to the data registration components, the logic specific to each field is encapsulated in AnnDataField classes (any child class of BaseAnnDataField).

An AnnDataField class contains four main functions to be implemented:

1. register_field sets up the relevant field on the AnnData object and returns the state registry for this field.

2. validate_field is a function called before register_field. E.g. checks if the data field is present on the AnnData object.

3. transfer_field is a function similar to register_field, but additionally takes a source state_registry which can modify the behavior of registration. E.g. for categorical fields we may want to maintain the source categories and append any additional categories on the target AnnData object for downstream transfer learning.

4. get_summary_stats is a function that takes a state_registry and outputs the summary stat dictionary. Note, this means the summary statistics must be a function of what is stored in state_registry.

Together, the set of AnnDataFields produces the registry detailed in part 1.
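The four functions can be sketched with a toy field that works on a plain dictionary of columns instead of an AnnData object. `ToyCategoricalField` is hypothetical: it does not subclass BaseAnnDataField, its signatures are simplified, and sorting the categories is an assumption made for this example.

```python
class ToyCategoricalField:
    """Simplified, hypothetical field; it does not subclass BaseAnnDataField."""

    def __init__(self, registry_key, obs_key):
        self.registry_key = registry_key
        self.obs_key = obs_key
        self.encoded_key = f"_toy_{registry_key}"

    def validate_field(self, obs):
        # Called first: fail early if the source column is missing.
        if self.obs_key not in obs:
            raise KeyError(f"{self.obs_key} not found in obs.")

    def register_field(self, obs):
        self.validate_field(obs)
        mapping = sorted(set(obs[self.obs_key]))
        return self._encode(obs, mapping)

    def transfer_field(self, source_state_registry, obs_target):
        # Reuse the source categories so the integer codes stay aligned.
        self.validate_field(obs_target)
        mapping = list(source_state_registry["categorical_mapping"])
        return self._encode(obs_target, mapping)

    def get_summary_stats(self, state_registry):
        # Computed only from what the state registry stores.
        n = len(state_registry["categorical_mapping"])
        return {f"n_{self.registry_key}": n}

    def _encode(self, obs, mapping):
        # Write the integer codes next to the original column.
        obs[self.encoded_key] = [mapping.index(v) for v in obs[self.obs_key]]
        return {"categorical_mapping": mapping, "original_key": self.obs_key}


field = ToyCategoricalField(registry_key="batch", obs_key="batch")
obs = {"batch": ["batch_1", "batch_0", "batch_1"]}
state_registry = field.register_field(obs)
print(obs["_toy_batch"])                        # [1, 0, 1]
print(field.get_summary_stats(state_registry))  # {'n_batch': 2}
```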

[92]:

print(adata_manager.fields)

[<scvi.data.fields._layer_field.LayerField object at 0x15004a9d0>, <scvi.data.fields._obs_field.CategoricalObsField object at 0x15004a7f0>]

[93]:

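batch_field = adata_manager.fields[1]  # the CategoricalObsField registered above
adata2 = scvi.data.synthetic_iid()
print("Before register_field:")
print(adata2)
print()

batch_state_registry = batch_field.register_field(adata2)
print("After register_field:")
print(adata2)
print()
print(f"State registry: {batch_state_registry}")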

Before register_field:
AnnData object with n_obs × n_vars = 400 × 100
obs: 'batch', 'labels'
uns: 'protein_names'
obsm: 'protein_expression'

After register_field:
AnnData object with n_obs × n_vars = 400 × 100
obs: 'batch', 'labels', '_scvi_batch'
uns: 'protein_names'
obsm: 'protein_expression'

State registry: {'categorical_mapping': array(['batch_0', 'batch_1'], dtype=object), 'original_key': 'batch'}


Notice how in this case the batch field’s register_field() function adds an additional .obs field that stores an encoded version of the adata.obs['batch'] column. The categorical mapping array order corresponds to the integer encoding of the category.

Important

Adding data in place to an AnnData object should be done with care. Most fields do not add any data; however, for categorical data it is faster to encode the categories in place instead of on-the-fly during data loading.

adata_manager.transfer_fields() can be used to produce a new AnnDataManager for a target AnnData object that follows the same structure as the original AnnData. This can be useful in models that are trained with an AnnData object, then used to make predictions on new query data. Under the hood, this calls the transfer_field() function of each field in the AnnDataManager.

In the example below, we try to transfer the batch field onto a new AnnData object with an extra batch category. In this case, the CategoricalObsField.transfer_field() function is parameterized with an extend_categories kwarg which, when True, will extend the batch category set as necessary. With extend_categories=False, transfer_field() will raise an error, which may be desired behavior in cases where we want to ensure that query data does not contain any batch categories missing from the train data.

[95]:

adata3 = scvi.data.synthetic_iid(n_batches=3)
try:
    adata_manager.transfer_fields(adata3)
except ValueError as e:
    print(e)

Category batch_2 not found in source registry. Cannot transfer setup without extend_categories = True.

[98]:

adata3_manager = adata_manager.transfer_fields(
    adata3, extend_categories=True
)
adata3_manager.view_registry()  # batch_2 appended to the batch category mapping

Anndata setup with scvi-tools version 0.15.0a0.

     Summary Statistics
┏━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃ Summary Stat Key ┃ Value ┃
┡━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
│     n_cells      │  600  │
│      n_vars      │  100  │
│     n_batch      │   3   │
└──────────────────┴───────┘

               Data Registry
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Registry Key ┃   scvi-tools Location    ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      x       │         adata.X          │
│    batch     │ adata.obs['_scvi_batch'] │
└──────────────┴──────────────────────────┘

                  batch State Registry
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃  Source Location   ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['batch'] │  batch_0   │          0          │
│                    │  batch_1   │          1          │
│                    │  batch_2   │          2          │
└────────────────────┴────────────┴─────────────────────┘
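The effect of extend_categories can be sketched in a few lines of plain Python. `transfer_categories` is a hypothetical helper written for this example, not a function from scvi-tools; it only mirrors the behavior described above, where unseen categories either raise an error or are appended to the end of the source mapping.

```python
def transfer_categories(source_mapping, target_values, extend_categories=False):
    """Hypothetical helper: align target categories with a source mapping."""
    mapping = list(source_mapping)
    for value in sorted(set(target_values)):
        if value in mapping:
            continue
        if not extend_categories:
            raise ValueError(
                f"Category {value} not found in source registry. "
                "Cannot transfer setup without extend_categories = True."
            )
        # New categories go at the end, so existing codes are unchanged.
        mapping.append(value)
    return mapping


source_mapping = ["batch_0", "batch_1"]
target_values = ["batch_0", "batch_2", "batch_1", "batch_2"]

try:
    transfer_categories(source_mapping, target_values)
except ValueError as err:
    print(err)

extended = transfer_categories(source_mapping, target_values, extend_categories=True)
print(extended)  # ['batch_0', 'batch_1', 'batch_2']
```

Appending rather than re-sorting keeps the integer codes of the source categories stable, which matters when a trained model has already learned parameters tied to those codes.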


## Data loading with AnnDataLoader

AnnDataLoader is the base dataloader for scvi-tools. In this section we show how the registered data is loaded by AnnDataLoader.

Parameters of AnnDataLoader:

• adata_manager: AnnDataManager object to load data from.

• shuffle: if True will shuffle the data beforehand.

• indices: can provide a subset of indices to load from (Useful when doing train/test splits).

• data_and_attributes: a dictionary where the key corresponds to its key in the data_registry and the value is the numpy data type. By default, all data is passed to the model as np.float32.

• data_loader_kwargs: additional keyword arguments passed to torch.utils.data.DataLoader.

First, we construct an AnnDataLoader and get the first batch. Then we will enumerate all the values in the batch. The variable data_batch contains the first batch of data. It is a dictionary whose values are the tensors registered in the previous section via register_fields().

[99]:

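# initialize an AnnDataLoader which will iterate over our anndata
adl = AnnDataLoader(adata_manager, shuffle=False, batch_size=10)

# get the first batch of data
data_batch = next(tensors for tensors in adl)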

Notice that the keys in data_batch are the same as the keys in the data_registry.

[100]:

print("data_batch_keys:")
print(data_batch.keys())

data_batch_keys:
dict_keys(['x', 'batch'])

[101]:

adata_manager.data_registry.keys()

[101]:

dict_keys(['x', 'batch'])
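The correspondence between registry keys and minibatch keys can be illustrated without any deep learning library. `iterate_minibatches` below is a hypothetical generator written for this example and uses nested lists in place of tensors; it is a sketch of the idea, not how AnnDataLoader is implemented.

```python
def iterate_minibatches(columns, registry_keys, batch_size):
    """Hypothetical loader: yields dicts keyed by the registry keys."""
    n_obs = len(columns[registry_keys[0]])
    for start in range(0, n_obs, batch_size):
        stop = start + batch_size
        yield {key: columns[key][start:stop] for key in registry_keys}


# Toy data: 5 cells, 2 genes, one integer-encoded batch column.
columns = {
    "x": [[1, 0], [0, 2], [3, 1], [0, 0], [2, 2]],
    "batch": [[0], [0], [1], [1], [1]],
}

minibatches = list(iterate_minibatches(columns, ["x", "batch"], batch_size=2))
first = minibatches[0]
print(list(first))       # ['x', 'batch']
print(first["batch"])    # [[0], [0]]
print(len(minibatches))  # 3 (the last minibatch holds a single cell)
```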


If we look at the batch annotations in the first minibatch from the data loader, they correspond to the batch annotations of the first 10 cells of our AnnData.

[102]:

adata.obs["batch"][:10]

[102]:

0    batch_0
1    batch_0
2    batch_0
3    batch_0
4    batch_0
5    batch_0
6    batch_0
7    batch_0
8    batch_0
9    batch_0
Name: batch, dtype: category
Categories (2, object): ['batch_0', 'batch_1']

[103]:

# CategoricalObsField.register_field() automatically encoded the categorical labels as integers
data_batch["batch"]

[103]:

tensor([[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.]])

[104]:

print(data_batch["x"].shape)  # shape is batch_size x n_genes
print(data_batch["batch"].shape)  # shape is batch_size x 1

torch.Size([10, 100])
torch.Size([10, 1])


By default, all the data loaded in scvi-tools is np.float32. If you wish to load as a different datatype, you can pass in a dictionary where the key corresponds to a key in the data registry and the value is the datatype.

In the following snippets, we first confirm the default dtype and then load x as np.int64 while keeping batch as np.float32.

[105]:

adl = AnnDataLoader(adata_manager, shuffle=False, batch_size=10)
data_batch = next(tensors for tensors in adl)

# by default data has the dtype np.float32
print(data_batch["x"].dtype)
print(data_batch["batch"].dtype)

torch.float32
torch.float32

[106]:

data_batch.keys()

[106]:

dict_keys(['x', 'batch'])


To specify the datatype of each key, we can use the data_and_attributes parameter of AnnDataLoader. Here we load x as np.int64 and keep everything else as np.float32.

[107]:

# the keys of data_and_attributes should correspond to keys in the data registry
data_registry_keys = adata_manager.data_registry.keys()
print("Data Registry keys:", data_registry_keys)

Data Registry keys: dict_keys(['x', 'batch'])

[108]:

data_and_attributes = {}
for key in data_registry_keys:
    if key == "x":
        data_and_attributes[key] = np.int64
    else:
        data_and_attributes[key] = np.float32
print(data_and_attributes)

{'x': <class 'numpy.int64'>, 'batch': <class 'numpy.float32'>}

[109]:

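adl = AnnDataLoader(
    adata_manager,
    shuffle=False,
    batch_size=10,
    data_and_attributes=data_and_attributes,
)
data_batch = next(tensors for tensors in adl)

# x is now loaded as int64, while batch keeps the default float32
print(data_batch["x"].dtype)
print(data_batch["batch"].dtype)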

torch.int64
torch.float32


Finally, if the data_and_attributes parameter is used, the data loader will only load the keys of the dictionary passed in. For example, if the only key in the dictionary passed to data_and_attributes is x, the data loader will only load x.

[110]:

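data_and_attributes = {"x": float}
adl = AnnDataLoader(
    adata_manager, shuffle=False, batch_size=10, data_and_attributes=data_and_attributes
)
data_batch = next(tensors for tensors in adl)

print(data_batch.keys())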

dict_keys(['x'])
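The two roles of data_and_attributes, selecting which keys are loaded and choosing the type each one is cast to, can be sketched as follows. `load_minibatch` is a hypothetical function for illustration only, and Python's built-in `int` and `float` stand in for NumPy dtypes.

```python
def load_minibatch(columns, data_and_attributes=None):
    """Hypothetical loader step: select keys and cast their values."""
    if data_and_attributes is None:
        # Default: load every registered key as float.
        data_and_attributes = {key: float for key in columns}
    return {
        key: [cast(value) for value in columns[key]]
        for key, cast in data_and_attributes.items()
    }


columns = {"x": [3.0, 0.0, 7.0], "batch": [0, 1, 1]}

default_batch = load_minibatch(columns)
print(default_batch)  # {'x': [3.0, 0.0, 7.0], 'batch': [0.0, 1.0, 1.0]}

x_only = load_minibatch(columns, {"x": int})
print(x_only)  # {'x': [3, 0, 7]}
```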


Below we demonstrate a toy use case where we can take advantage of the AnnDataLoader to minibatch data from our AnnData object into a model. In this example, we train a simple linear regression model.

Important

The DataLoader will not move data to a device (e.g., GPU). This is the responsibility of the user. Alternatively, frameworks like PyTorch Lightning will do this automatically for users.

[111]:

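# Initialize synthetic_iid data and register with an AnnDataManager
n_genes, n_labels = 10, 3
adata = scvi.data.synthetic_iid(n_genes=n_genes, n_labels=n_labels)
anndata_fields = [
    LayerField(registry_key="x", layer=None, is_count_data=True),
    CategoricalObsField(registry_key="labels", obs_key="labels"),
]
adata_manager = AnnDataManager(fields=anndata_fields)
adata_manager.register_fields(adata)

# Regression model
linear_reg_model = torch.nn.Linear(n_genes, 1)

# Define loss and optimizer
loss_fn = torch.nn.MSELoss(reduction="sum")
# NOTE: the optimizer and learning rate are assumed; any torch.optim optimizer works here
optim = torch.optim.Adam(linear_reg_model.parameters(), lr=0.05)


def train(x, labels):
    # reset gradients accumulated by the previous step
    optim.zero_grad()
    # run the model forward on the data
    label_pred = linear_reg_model(x).squeeze(-1)
    # calculate the mse loss
    loss = loss_fn(label_pred, labels.squeeze())
    # backpropagate
    loss.backward()
    optim.step()
    return loss


# drop a minibatch if it has 3 or fewer observations
adl = AnnDataLoader(
    adata_manager,
    batch_size=128,
    drop_last=3,
    shuffle=True,
)

for i in range(5):
    for data in adl:
        loss = train(data["x"], data["labels"])
    print(f"[iteration {i + 1}] loss: {loss.item()}")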

[iteration 1] loss: 850.5758056640625
[iteration 2] loss: 278.2804870605469
[iteration 3] loss: 257.3394470214844
[iteration 4] loss: 74.33056640625
[iteration 5] loss: 45.399696350097656


## scvi-tools Data Registration

scvi-tools models produce an AnnDataManager instance in the setup_anndata() function for the purpose of data registration.

setup_anndata() is used to set up data fields specific to each model.

Here we will go over the parameters of one instance of a setup_anndata() method, scvi.model.SCVI.setup_anndata():

• adata is the input AnnData object.

• layer is the key in adata.layers to use for the input data matrix. By default, this is None and the input data matrix will be pulled from adata.X.

• batch_key is the key in adata.obs for batch information. If this is None, will assume that all the data is the same batch.

• labels_key is the key in adata.obs for label information. If this is None, will assume that all the data has the same label.

• size_factor_key is the key in adata.obs that optionally stores size factors for computing the likelihood. If this is None, the library size is used to compute the size factor.

• categorical_covariate_keys is a list of keys in adata.obs for categorical covariates.

• continuous_covariate_keys is a list of keys in adata.obs for continuous covariates.

Under the hood:

• For all categorical data (batch, labels, categorical covariates), scvi will automatically compute a mapping from values to integers. E.g. ['a','b','c','a'] will become [0,1,2,0].

• For data fields registered with scvi.model.SCVI.setup_anndata(), scvi will copy the data to a separate field in the AnnData object.

• batch_key is copied to adata.obs['_scvi_batch'] with its integer encoding.

• labels_key is copied to adata.obs['_scvi_labels'] with its integer encoding.

• keys in categorical_covariate_keys are concatenated, saved as a pandas DataFrame, and stored in adata.obsm['_scvi_extra_categorical_covs'] with their integer encodings.

• keys in continuous_covariate_keys are concatenated and saved as a pandas DataFrame and stored in adata.obsm['_scvi_extra_continuous_covs']
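The integer encoding of several categorical covariates can be sketched with plain lists. This is an illustration only: `encode_column`, the `donor` and `chemistry` column names, and the choice to order categories by sorting are all assumptions of this example rather than details confirmed by the text above.

```python
def encode_column(values):
    """Map each distinct value to an integer, in sorted order."""
    mapping = sorted(set(values))
    return [mapping.index(v) for v in values], mapping


print(encode_column(["a", "b", "c", "a"])[0])  # [0, 1, 2, 0]

# Several categorical covariates: encode each column, then join them row-wise.
obs = {
    "donor": ["d2", "d1", "d2", "d1"],
    "chemistry": ["v3", "v3", "v2", "v3"],
}
codes, mappings = {}, {}
for key, values in obs.items():
    codes[key], mappings[key] = encode_column(values)

encoded_rows = [list(row) for row in zip(*codes.values())]
print(encoded_rows)  # [[1, 1], [0, 1], [1, 0], [0, 1]]
print(mappings)      # {'donor': ['d1', 'd2'], 'chemistry': ['v2', 'v3']}
```

Each covariate keeps its own mapping, so the same integer can mean different things in different columns.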

In the following code, we first format an example AnnData object to set up for scvi-tools, then call scvi.model.SCVI.setup_anndata() to register all the tensors we want to load into the model during training. For our example AnnData object, we build off the synthetic_iid() dataset, copy X to a layer, and add continuous and categorical covariates to the AnnData.

[112]:

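adata = scvi.data.synthetic_iid()
adata.layers["raw_counts"] = adata.X.copy()  # copy X to a layer
adata.obs["my_categorical_covariate"] = ["A"] * 200 + ["B"] * 200
adata.obs["my_continuous_covariate"] = np.random.normal(size=adata.n_obs)  # illustrative values
adata
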
AnnData object with n_obs × n_vars = 400 × 100
obs: 'batch', 'labels', 'my_categorical_covariate', 'my_continuous_covariate'
uns: 'protein_names'
obsm: 'protein_expression'
layers: 'raw_counts'

[113]:

scvi.model.SCVI.setup_anndata(
    adata,
    batch_key="batch",
    labels_key="labels",
    layer="raw_counts",
    categorical_covariate_keys=["my_categorical_covariate"],
    continuous_covariate_keys=["my_continuous_covariate"],
)


Under the hood, this method creates an AnnDataManager instance and stores it in a model-specific manager store until a model is initialized with the same AnnData object. We can visualize the resulting registry via view_anndata_setup(). This calls view_registry() under the hood.

[114]:

model = scvi.model.SCVI(adata)
model.view_anndata_setup()

Anndata setup with scvi-tools version 0.15.0a0.

Setup via SCVI.setup_anndata with arguments:

{
│   'layer': 'raw_counts',
│   'batch_key': 'batch',
│   'labels_key': 'labels',
│   'size_factor_key': None,
│   'categorical_covariate_keys': ['my_categorical_covariate'],
│   'continuous_covariate_keys': ['my_continuous_covariate']
}

         Summary Statistics
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃     Summary Stat Key     ┃ Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
│         n_cells          │  400  │
│          n_vars          │  100  │
│         n_batch          │   2   │
│         n_labels         │   3   │
│ n_extra_categorical_covs │   1   │
│ n_extra_continuous_covs  │   1   │
└──────────────────────────┴───────┘

                             Data Registry
┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Registry Key      ┃            scvi-tools Location             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│           X            │         adata.layers['raw_counts']         │
│         batch          │          adata.obs['_scvi_batch']          │
│         labels         │         adata.obs['_scvi_labels']          │
│ extra_categorical_covs │ adata.obsm['_scvi_extra_categorical_covs'] │
│ extra_continuous_covs  │ adata.obsm['_scvi_extra_continuous_covs']  │
└────────────────────────┴────────────────────────────────────────────┘

                  batch State Registry
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃  Source Location   ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['batch'] │  batch_0   │          0          │
│                    │  batch_1   │          1          │
└────────────────────┴────────────┴─────────────────────┘

                  labels State Registry
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃   Source Location   ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['labels'] │  label_0   │          0          │
│                     │  label_1   │          1          │
│                     │  label_2   │          2          │
└─────────────────────┴────────────┴─────────────────────┘

                   extra_categorical_covs State Registry
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃            Source Location            ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['my_categorical_covariate'] │     A      │          0          │
│                                       │     B      │          1          │
│                                       │            │                     │
└───────────────────────────────────────┴────────────┴─────────────────────┘

  extra_continuous_covs State Registry
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃           Source Location            ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['my_continuous_covariate'] │
└──────────────────────────────────────┘


Each model defines a set of appropriate AnnDataFields and orchestrates calls to these functions and stores the resulting registry as an instance variable. As mentioned before, the AnnDataManager is constructed during setup_anndata() and retrieved during model initialization.

Here we have an abbreviated version of a setup_anndata() implementation for a model that only takes a layer kwarg and a batch_key:

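@classmethod
def setup_anndata(
    cls,
    adata: AnnData,
    layer: Optional[str] = None,
    batch_key: Optional[str] = None,
    **kwargs,  # Used when loading a model with a new AnnData object.
):
    # Capture the arguments passed to this function for the registry
    setup_method_args = cls._get_setup_method_args(**locals())
    anndata_fields = [
        LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
    ]
    # Build the manager, register the fields, and store it for model initialization
    adata_manager = AnnDataManager(
        fields=anndata_fields, setup_method_args=setup_method_args
    )
    adata_manager.register_fields(adata, **kwargs)
    cls.register_manager(adata_manager)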

The setup_anndata() function itself is quite simple since any complexity in preprocessing is contained within the AnnDataField functions. By factorizing the preprocessing steps into each subclass, model developers can easily extend and reuse logic across models and fields.
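The pattern of constructing a manager during setup and retrieving it at model initialization can be sketched with a class-level store keyed by a unique identifier. Everything in this block is hypothetical (`ToyModel`, `setup_data`, the `_toy_uuid` key, and the dictionary used as a dataset); it illustrates the store-and-retrieve idea under those assumptions and is not the scvi-tools implementation.

```python
import uuid


class ToyModel:
    """Hypothetical model class; only the store-and-retrieve pattern matters."""

    _manager_store = {}  # class-level: shared by all instances of the model

    @classmethod
    def setup_data(cls, dataset, **setup_args):
        # Tag the dataset so it can be recognized later.
        dataset.setdefault("_toy_uuid", str(uuid.uuid4()))
        cls._manager_store[dataset["_toy_uuid"]] = {"setup_args": setup_args}

    def __init__(self, dataset):
        key = dataset.get("_toy_uuid")
        if key not in self._manager_store:
            raise ValueError("Call setup_data() on this dataset first.")
        self.manager = self._manager_store[key]


dataset = {"X": [[1, 2], [3, 4]]}
ToyModel.setup_data(dataset, batch_key="batch")
model = ToyModel(dataset)
print(model.manager["setup_args"])  # {'batch_key': 'batch'}
```

Keying the store by an identifier written onto the dataset lets the model find its setup without the user passing the manager around explicitly.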