scvi.dataloaders.SemiSupervisedDataLoader

class scvi.dataloaders.SemiSupervisedDataLoader(adata_manager, n_samples_per_label=None, indices=None, shuffle=False, batch_size=128, data_and_attributes=None, drop_last=False, **data_loader_kwargs)

DataLoader that supports semisupervised training.
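This page does not spell out how labeled and unlabeled observations are combined during training. One common pattern is to iterate over all observations and cycle through the (usually much smaller) labeled subset, so that every step sees both. The pure-Python sketch below illustrates that pattern. It is an assumption, not a description of scvi's internals, and the helpers `chunk` and `paired_batches` are hypothetical names.

```python
from itertools import cycle
from typing import Iterator, List, Sequence, Tuple


def chunk(indices: Sequence[int], batch_size: int) -> List[List[int]]:
    """Split indices into consecutive minibatches (the last one may be smaller)."""
    return [list(indices[i:i + batch_size]) for i in range(0, len(indices), batch_size)]


def paired_batches(
    all_indices: Sequence[int],
    labeled_indices: Sequence[int],
    batch_size: int,
) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (full_batch, labeled_batch) pairs (hypothetical helper).

    The full index set determines the epoch length; the labeled set is cycled
    so that each full minibatch is paired with a labeled minibatch.
    If labeled_indices is empty, nothing is yielded.
    """
    full = chunk(all_indices, batch_size)
    labeled = chunk(labeled_indices, batch_size)
    return zip(full, cycle(labeled))


if __name__ == "__main__":
    # Hypothetical data: 10 observations, of which only 0, 3 and 7 carry a label.
    for full_batch, labeled_batch in paired_batches(range(10), [0, 3, 7], batch_size=4):
        print(full_batch, labeled_batch)
```

With 10 observations and a batch size of 4, this yields three pairs, each reusing the single labeled minibatch `[0, 3, 7]`.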

Parameters:
  • adata_manager (AnnDataManager) – AnnDataManager object that has been created via setup_anndata.

  • n_samples_per_label (Optional[int]) – Maximum number of observations to sample from each label class per epoch. If None (the default), labels are not subsampled.

  • indices (Optional[List[int]]) – Indices of the observations in adata to load.

  • shuffle (bool) – Whether to shuffle the data.

  • batch_size (int) – Minibatch size to load at each iteration.

  • data_and_attributes (Optional[dict]) – Dictionary mapping keys of the data registry (adata_manager.data_registry) to the NumPy dtype each should be loaded as (later converted to a torch tensor). If None, defaults to all registered data.

  • data_loader_kwargs – Keyword arguments passed to torch.utils.data.DataLoader.

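  • drop_last (Union[bool, int]) – If True, drop the last minibatch when it is smaller than batch_size. If an int, drop the last minibatch when it contains fewer than drop_last observations. If False, iterate over all minibatches.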
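To make the interaction of batch_size, shuffle and drop_last concrete, here is a hypothetical helper that splits observation indices into minibatches. `make_minibatches` is not part of scvi, and the integer behavior of drop_last is an assumption: an int drops the final minibatch when it holds fewer than that many observations, while a bool follows the usual PyTorch meaning.

```python
import random
from typing import List, Optional, Sequence, Union


def make_minibatches(
    indices: Sequence[int],
    batch_size: int = 128,
    shuffle: bool = False,
    drop_last: Union[bool, int] = False,
    seed: Optional[int] = None,
) -> List[List[int]]:
    """Hypothetical helper: split observation indices into minibatches."""
    order = list(indices)
    if shuffle:
        # Seeded RNG so the shuffle is reproducible when seed is given.
        random.Random(seed).shuffle(order)
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if batches:
        last = len(batches[-1])
        # bool is a subclass of int in Python, so test for bool first.
        if isinstance(drop_last, bool):
            should_drop = drop_last and last < batch_size
        else:
            # Assumed int semantics: drop the last batch if it is shorter than drop_last.
            should_drop = last < drop_last
        if should_drop:
            batches.pop()
    return batches
```

For 10 indices and batch_size=4, drop_last=True or drop_last=3 discards the trailing batch of 2, whereas drop_last=2 keeps it.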

Methods summary

  • check_worker_number_rationality()

  • resample_labels() – Resamples the labeled data.

  • subsample_labels() – Subsamples each label class by taking up to n_samples_per_label samples per class.

Attributes

These attributes are inherited from torch.utils.data.DataLoader.

  • SemiSupervisedDataLoader.multiprocessing_context

  • SemiSupervisedDataLoader.dataset: Dataset[T_co]

  • SemiSupervisedDataLoader.batch_size: Optional[int]

  • SemiSupervisedDataLoader.num_workers: int

  • SemiSupervisedDataLoader.pin_memory: bool

  • SemiSupervisedDataLoader.drop_last: bool

  • SemiSupervisedDataLoader.timeout: float

  • SemiSupervisedDataLoader.sampler: Union[Sampler, Iterable]

  • SemiSupervisedDataLoader.pin_memory_device: str

  • SemiSupervisedDataLoader.prefetch_factor: Optional[int]

Methods

SemiSupervisedDataLoader.check_worker_number_rationality()

Inherited from torch.utils.data.DataLoader. Warns if num_workers is larger than the number of CPUs available to the process.
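The check can be illustrated with a simplified, hypothetical analogue. `available_cpus` and `check_worker_count` are illustrative names, not PyTorch's implementation: the sketch compares a requested worker count against the CPUs this process may use and emits a warning when there are more workers than CPUs.

```python
import os
import warnings
from typing import Optional


def available_cpus() -> Optional[int]:
    """CPUs this process may run on, or None if it cannot be determined."""
    if hasattr(os, "sched_getaffinity"):  # Linux: respects CPU affinity masks
        return len(os.sched_getaffinity(0))
    return os.cpu_count()  # may return None on some platforms


def check_worker_count(num_workers: int) -> bool:
    """Hypothetical check: warn and return False if num_workers exceeds available CPUs."""
    max_workers = available_cpus()
    if num_workers > 0 and max_workers is not None and num_workers > max_workers:
        warnings.warn(
            f"num_workers={num_workers} exceeds the {max_workers} CPUs available; "
            "data loading may be slow or the process may freeze.",
            UserWarning,
        )
        return False
    return True
```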

SemiSupervisedDataLoader.resample_labels()

Resamples the labeled data.
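When n_samples_per_label is set, a single draw covers only part of the labeled data, so redrawing between epochs lets training see different labeled observations over time. The stdlib-only sketch below shows that idea. `LabelResampler` is a hypothetical class, not part of scvi, and it assumes that resampling redraws up to n_samples_per_label indices per class and that the caller triggers it, for example once per epoch.

```python
import random
from typing import Dict, List, Optional


class LabelResampler:
    """Hypothetical stand-in for the labeled part of a semi-supervised loader."""

    def __init__(
        self,
        labeled_locs: Dict[str, List[int]],
        n_samples_per_label: Optional[int],
        seed: int = 0,
    ):
        # labeled_locs maps each label class to the indices of its observations.
        self.labeled_locs = labeled_locs
        self.n_samples_per_label = n_samples_per_label
        self.rng = random.Random(seed)
        self.labeled_indices = self._draw()

    def _draw(self) -> List[int]:
        drawn: List[int] = []
        for locs in self.labeled_locs.values():
            if self.n_samples_per_label is None or len(locs) <= self.n_samples_per_label:
                drawn.extend(locs)  # no limit, or class is small: keep everything
            else:
                drawn.extend(self.rng.sample(locs, self.n_samples_per_label))
        return drawn

    def resample_labels(self) -> None:
        """Replace the labeled subset with a fresh draw."""
        self.labeled_indices = self._draw()


if __name__ == "__main__":
    resampler = LabelResampler({"B": [0, 1, 2, 3, 4, 5], "T": [10, 11]}, n_samples_per_label=2, seed=1)
    for epoch in range(3):
        resampler.resample_labels()  # e.g. called at the start of each epoch
        print(epoch, sorted(resampler.labeled_indices))
```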

SemiSupervisedDataLoader.subsample_labels()

Subsamples each label class by taking up to n_samples_per_label samples per class.
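The documented behavior ("up to n_samples_per_label samples per class") can be sketched with NumPy as follows. This is a hypothetical re-implementation, not scvi's source. It assumes labels are given as one value per observation and that observations carrying a designated unlabeled category are excluded from the labeled subset. Classes with at most n_samples_per_label members are kept whole.

```python
from typing import Optional

import numpy as np


def subsample_labels(
    labels,
    unlabeled_category,
    n_samples_per_label: Optional[int] = None,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Hypothetical per-class subsampling; returns indices of selected labeled observations."""
    labels = np.asarray(labels)
    rng = np.random.default_rng() if rng is None else rng
    selected = []
    for label in np.unique(labels):  # classes in sorted order
        if label == unlabeled_category:
            continue  # unlabeled observations never enter the labeled subset
        locs = np.flatnonzero(labels == label)
        if n_samples_per_label is None or len(locs) <= n_samples_per_label:
            selected.append(locs)  # keep the whole class
        else:
            selected.append(rng.choice(locs, size=n_samples_per_label, replace=False))
    if not selected:
        return np.array([], dtype=int)
    return np.concatenate(selected)
```

For labels `["B", "B", "B", "T", "Unknown", "T", "B", "Unknown"]` with `"Unknown"` as the unlabeled category and no limit, this returns `[0, 1, 2, 6, 3, 5]`; with n_samples_per_label=2, two of the four "B" indices are drawn and both "T" indices are kept.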