Source code for scvi.dataset.anndataset

import logging
import operator
import os
from functools import reduce

import anndata
import numpy as np
import pandas as pd
import scipy.sparse as sp_sparse
from typing import Dict, Optional

from scvi.dataset.dataset import (
    DownloadableDataset,
    GeneExpressionDataset,
    CellMeasurement,
)

logger = logging.getLogger(__name__)


[docs]class AnnDatasetFromAnnData(GeneExpressionDataset): """Forms a ``GeneExpressionDataset`` from a ``anndata.AnnData`` object. Parameters ---------- ad ``anndata.AnnData`` instance. batch_label AnnData obs column name for batches ctype_label AnnData obs column name for cell_types class_label AnnData obs column name for labels use_raw if True, copies data from .raw attribute of AnnData """ def __init__( self, ad: anndata.AnnData, batch_label: str = "batch_indices", ctype_label: str = "cell_types", class_label: str = "labels", use_raw: bool = False, cell_measurements_col_mappings: Optional[Dict[str, str]] = None, ): super().__init__() ( X, batch_indices, labels, gene_names, cell_types, obs, obsm, var, _, uns, ) = extract_data_from_anndata( ad, batch_label=batch_label, ctype_label=ctype_label, class_label=class_label, use_raw=use_raw, ) # Dataset API takes a dict as input obs = obs.to_dict(orient="list") var = var.to_dict(orient="list") # add external cell measurements Ys = [] if cell_measurements_col_mappings is not None: for name, attr_name in cell_measurements_col_mappings.items(): columns = uns[attr_name] measurement = CellMeasurement( name=name, data=obsm[name], columns_attr_name=attr_name, columns=columns, ) Ys.append(measurement) self.populate_from_data( X=X, Ys=Ys, labels=labels, batch_indices=batch_indices, gene_names=gene_names, cell_types=cell_types, cell_attributes_dict=obs, gene_attributes_dict=var, ) self.filter_cells_by_count()
[docs]class DownloadableAnnDataset(DownloadableDataset): """Forms a ``DownloadableDataset`` from a `.h5ad` file using the ``anndata`` package. Parameters ---------- filename Name of the `.h5ad` file to save/load. save_path Location to use when saving/loading the data. url URL pointing to the data which will be downloaded if it's not already in ``save_path``. delayed_populating Switch for delayed populating mechanism. batch_label AnnData obs column name for batches ctype_label AnnData obs column name for cell_types class_label AnnData obs column name for labels use_raw if True, copies data from .raw attribute of AnnData Examples -------- >>> # Loading a local dataset >>> dataset = DownloadableAnnDataset("TM_droplet_mat.h5ad", save_path = 'data/') """ def __init__( self, filename: str = "anndataset", save_path: str = "data/", url: str = None, delayed_populating: bool = False, batch_label: str = "batch_indices", ctype_label: str = "cell_types", class_label: str = "labels", use_raw: bool = False, cell_measurements_col_mappings: Optional[Dict[str, str]] = None, ): self.batch_label = batch_label self.ctype_label = ctype_label self.class_label = class_label self.use_raw = use_raw self.cell_measurements_col_mappings_temp = cell_measurements_col_mappings super().__init__( urls=url, filenames=filename, save_path=save_path, delayed_populating=delayed_populating, )
[docs] def populate(self): ad = anndata.read_h5ad( os.path.join(self.save_path, self.filenames[0]) ) # obs = cells, var = genes # extract GeneExpressionDataset relevant attributes # and provide access to annotations from the underlying AnnData object. ( X, batch_indices, labels, gene_names, cell_types, obs, obsm, var, _, uns, ) = extract_data_from_anndata( ad, batch_label=self.batch_label, ctype_label=self.ctype_label, class_label=self.class_label, use_raw=self.use_raw, ) # Dataset API takes a dict as input obs = obs.to_dict(orient="list") var = var.to_dict(orient="list") # add external cell measurements Ys = [] if self.cell_measurements_col_mappings_temp is not None: for name, attr_name in self.cell_measurements_col_mappings_temp.items(): columns = uns[attr_name] measurement = CellMeasurement( name=name, data=obsm[name], columns_attr_name=attr_name, columns=columns, ) Ys.append(measurement) self.populate_from_data( X=X, Ys=Ys, labels=labels, batch_indices=batch_indices, gene_names=gene_names, cell_types=cell_types, cell_attributes_dict=obs, gene_attributes_dict=var, ) self.filter_cells_by_count() del self.cell_measurements_col_mappings_temp
def extract_data_from_anndata( ad: anndata.AnnData, batch_label: str = "batch_indices", ctype_label: str = "cell_types", class_label: str = "labels", use_raw: bool = False, ): data, labels, batch_indices, gene_names, cell_types = None, None, None, None, None # We use obs that will contain all the observation except those associated with # batch_label, ctype_label and class_label. obs = ad.obs.copy() if use_raw: counts = ad.raw.X else: counts = ad.X # treat all possible cases according to anndata doc if isinstance(counts, np.ndarray): data = counts.copy() if isinstance(counts, pd.DataFrame): data = counts.values.copy() if sp_sparse.issparse(counts): # keep sparsity above 1 Gb in dense form if reduce(operator.mul, counts.shape) * counts.dtype.itemsize < 1e9: logger.info("Dense size under 1Gb, casting to dense format (np.ndarray).") data = counts.toarray() else: data = counts.copy() gene_names = np.asarray(ad.var.index.values, dtype=str) if batch_label in obs.columns: batch_indices = obs.pop(batch_label).values if ctype_label in obs.columns: cell_types = obs.pop(ctype_label) res = pd.factorize(cell_types) labels = res[0].astype(int) cell_types = np.array(res[1]).astype(str) elif class_label in obs.columns: labels = obs.pop(class_label) return ( data, batch_indices, labels, gene_names, cell_types, obs, ad.obsm, ad.var, ad.varm, ad.uns, )