scvi.dataloaders.SemiSupervisedDataSplitter.transfer_batch_to_device

SemiSupervisedDataSplitter.transfer_batch_to_device(batch, device=None)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

The following data types (and any arbitrary nesting of them) are supported out of the box: torch.Tensor or anything that implements .to(...), list, dict, tuple.

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).
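As a minimal sketch of the out-of-the-box case, tensors nested inside dicts, lists, and tuples need no override; the move_data_to_device() helper referenced under See also handles such nesting (the exact import path may differ between Lightning versions):

import torch
from pytorch_lightning.utilities import move_data_to_device

# Tensors nested inside dicts, lists and tuples are supported out of the
# box: every tensor found in the structure is moved recursively.
batch = {
    "X": torch.randn(8, 100),
    "labels": torch.zeros(8, dtype=torch.long),
    "extras": (torch.ones(8, 1), [torch.arange(8)]),
}
batch_on_device = move_data_to_device(batch, torch.device("cpu"))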

Note

This hook should only transfer the data and not modify it, nor should it move the data to any device other than the one passed in as an argument (unless you know what you are doing).

Note

This hook only runs on single GPU training and DDP (no data-parallel). Data-parallel support will come in the near future.

Parameters
batch : Any

A batch of data that needs to be transferred to a new device.

device : device | None (default: None)

The target device as defined in PyTorch.

Return type

Any

Returns

A reference to the data on the new device.

Example:

def transfer_batch_to_device(self, batch, device):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    else:
        batch = super().transfer_batch_to_device(batch, device)
    return batch
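
The CustomBatch class in the example above is illustrative, not part of scvi-tools; a hypothetical definition it could refer to is:

from dataclasses import dataclass
import torch

@dataclass
class CustomBatch:
    # Hypothetical container; the override above moves its tensor
    # fields to the target device one by one.
    samples: torch.Tensor
    targets: torch.Tensor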
Raises

MisconfigurationException – If using data-parallel, Trainer(accelerator='dp').

See also

  • move_data_to_device()

  • apply_to_collection()