scvi.dataloaders.SemiSupervisedDataSplitter

class scvi.dataloaders.SemiSupervisedDataSplitter(adata, unlabeled_category, train_size=0.9, validation_size=None, n_samples_per_label=None, use_gpu=False, **kwargs)[source]

Creates data loaders train_set, validation_set, test_set.

If train_size + validation_size < 1 then test_set is non-empty. The ratio between labeled and unlabeled data in adata will be preserved in the train/test/val sets.
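The splitting behavior described above can be sketched as a stratified split: labeled and unlabeled indices are partitioned separately so each resulting set keeps roughly the same labeled/unlabeled ratio. This is an illustrative sketch only, not the scvi-tools implementation:

```python
import numpy as np

def stratified_split(labels, unlabeled_category, train_size=0.9,
                     validation_size=None, seed=0):
    """Split indices into train/val/test, preserving the
    labeled/unlabeled ratio. Illustrative sketch only."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    if validation_size is None:
        validation_size = 1 - train_size
    train, val, test = [], [], []
    # Split unlabeled and labeled indices separately so each subset
    # keeps (approximately) the same labeled/unlabeled proportions.
    for mask in (labels == unlabeled_category, labels != unlabeled_category):
        idx = rng.permutation(np.where(mask)[0])
        n_train = int(train_size * len(idx))
        n_val = int(validation_size * len(idx))
        train.extend(idx[:n_train])
        val.extend(idx[n_train:n_train + n_val])
        test.extend(idx[n_train + n_val:])
    return train, val, test
```

With train_size=0.9 and validation_size=0.1 the test set is empty; lowering either fraction leaves the remainder in the test set.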

Parameters
adata : AnnData

AnnData to split into train/test/val sets

unlabeled_category

Category to treat as unlabeled

train_size : float (default: 0.9)

Fraction of cells to include in the train set

validation_size : float | None (default: None)

Fraction of cells to include in the validation set; if None, the validation set receives the cells left over after the train split

n_samples_per_label : int | None (default: None)

Number of subsamples for each label class to sample per epoch
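The per-epoch subsampling that n_samples_per_label controls can be sketched as follows (an assumed illustration, not the scvi-tools source): each label class contributes at most n_samples_per_label indices, and smaller classes contribute everything they have.

```python
import random

def subsample_per_label(labels, n_samples_per_label, seed=0):
    """Pick up to n_samples_per_label indices per label class.
    Illustrative sketch only, not the scvi-tools implementation."""
    rng = random.Random(seed)
    by_label = {}
    for i, lab in enumerate(labels):
        by_label.setdefault(lab, []).append(i)
    chosen = []
    for lab, idx in by_label.items():
        if len(idx) > n_samples_per_label:
            # Large classes are randomly subsampled each epoch.
            chosen.extend(rng.sample(idx, n_samples_per_label))
        else:
            # Small classes are used in full.
            chosen.extend(idx)
    return sorted(chosen)
```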

use_gpu : bool (default: False)

Use default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str, e.g., ‘cuda:0’), or use CPU (if False).

**kwargs

Keyword args for data loader. If adata has labeled data, data loader class is SemiSupervisedDataLoader, else data loader class is AnnDataLoader.
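The loader-class selection rule described above amounts to checking whether any labeled observations exist. A minimal sketch of that rule (assumed from the description, not the actual scvi-tools source):

```python
def pick_loader_class(labels, unlabeled_category):
    """Return the name of the data loader class to use, per the
    rule described above. Illustrative sketch only."""
    # Any observation whose label differs from the unlabeled
    # category counts as labeled data.
    has_labeled = any(lab != unlabeled_category for lab in labels)
    return "SemiSupervisedDataLoader" if has_labeled else "AnnDataLoader"
```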

Examples

>>> import scvi
>>> from scvi.dataloaders import SemiSupervisedDataSplitter
>>> adata = scvi.data.synthetic_iid()
>>> unknown_label = 'label_0'
>>> splitter = SemiSupervisedDataSplitter(adata, unknown_label)
>>> splitter.setup()
>>> train_dl = splitter.train_dataloader()

Attributes

dims

A tuple describing the shape of your data.

has_prepared_data

Return bool letting you know if datamodule.prepare_data() has been called or not.

has_setup_fit

Return bool letting you know if datamodule.setup(stage='fit') has been called or not.

has_setup_predict

Return bool letting you know if datamodule.setup(stage='predict') has been called or not.

has_setup_test

Return bool letting you know if datamodule.setup(stage='test') has been called or not.

has_setup_validate

Return bool letting you know if datamodule.setup(stage='validate') has been called or not.

has_teardown_fit

Return bool letting you know if datamodule.teardown(stage='fit') has been called or not.

has_teardown_predict

Return bool letting you know if datamodule.teardown(stage='predict') has been called or not.

has_teardown_test

Return bool letting you know if datamodule.teardown(stage='test') has been called or not.

has_teardown_validate

Return bool letting you know if datamodule.teardown(stage='validate') has been called or not.

name

test_transforms

Optional transforms (or collection of transforms) you can apply to test dataset

train_transforms

Optional transforms (or collection of transforms) you can apply to train dataset

val_transforms

Optional transforms (or collection of transforms) you can apply to validation dataset

Methods

add_argparse_args(parent_parser, **kwargs)

Extends existing argparse by default LightningDataModule attributes.

from_argparse_args(args, **kwargs)

Create an instance from CLI arguments.

from_datasets([train_dataset, val_dataset, …])

Create an instance from torch.utils.data.Dataset.

get_init_arguments_and_types()

Scans the DataModule signature and returns argument names, types and default values.

on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_load_checkpoint(checkpoint)

Called by Lightning to restore your model.

on_predict_dataloader()

Called before requesting the predict dataloader.

on_save_checkpoint(checkpoint)

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

on_test_dataloader()

Called before requesting the test dataloader.

on_train_dataloader()

Called before requesting the train dataloader.

on_val_dataloader()

Called before requesting the val dataloader.

predict_dataloader()

Implement one or multiple PyTorch DataLoaders for prediction.

prepare_data()

Use this to download and prepare data.

setup([stage])

Split indices in train/test/val sets.

size([dim])

Return the dimension of each input either as a tuple or list of tuples.

teardown([stage])

Called at the end of fit (train + validate), validate, test, predict, or tune.

test_dataloader()

Implement one or multiple PyTorch DataLoaders for testing.

train_dataloader()

Implement one or more PyTorch DataLoaders for training.

transfer_batch_to_device(batch[, device])

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

val_dataloader()

Implement one or multiple PyTorch DataLoaders for validation.