scvi.dataloaders.DataSplitter.on_after_batch_transfer

DataSplitter.on_after_batch_transfer(batch, dataloader_idx)

Override this hook to alter a batch, or to apply augmentations to it, after it has been transferred to the device.
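In Lightning, the three batch-transfer hooks run in a fixed order: on_before_batch_transfer, then transfer_batch_to_device, then on_after_batch_transfer. The pure-Python sketch below simulates that sequence so the hook's contract is visible without a GPU: it receives the already-transferred batch, and whatever it returns replaces the batch. SketchDataModule, AugmentingDataModule and run_transfer_hooks are hypothetical stand-ins, not part of scvi-tools or Lightning. The simulated "transfer" only records a device string instead of moving tensors.

```python
from typing import Any, Dict

Batch = Dict[str, Any]


class SketchDataModule:
    """Hypothetical stand-in for a data module such as DataSplitter."""

    def on_before_batch_transfer(self, batch: Batch, dataloader_idx: int) -> Batch:
        # Called first, while the batch is still on the host; identity by default
        return batch

    def transfer_batch_to_device(self, batch: Batch, device: str, dataloader_idx: int) -> Batch:
        # Simulated transfer: records the target device instead of moving tensors
        return {**batch, "device": device}

    def on_after_batch_transfer(self, batch: Batch, dataloader_idx: int) -> Batch:
        # Called last, once the batch is on the device; identity by default
        return batch


class AugmentingDataModule(SketchDataModule):
    def on_after_batch_transfer(self, batch: Batch, dataloader_idx: int) -> Batch:
        # Hypothetical device-side augmentation: double every value of "x"
        batch["x"] = [2 * v for v in batch["x"]]
        # Whatever is returned here replaces the batch downstream
        return batch


def run_transfer_hooks(dm: SketchDataModule, batch: Batch, device: str, dataloader_idx: int = 0) -> Batch:
    # Simulates the order in which Lightning invokes the three hooks
    batch = dm.on_before_batch_transfer(batch, dataloader_idx)
    batch = dm.transfer_batch_to_device(batch, device, dataloader_idx)
    return dm.on_after_batch_transfer(batch, dataloader_idx)
```

For example, run_transfer_hooks(AugmentingDataModule(), {"x": [1, 2, 3]}, "cuda:0") returns {"x": [2, 4, 6], "device": "cuda:0"}: the augmentation sees a batch that has already passed through the transfer step.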

Warning

dataloader_idx is currently always 0; support for the true DataLoader index will be added in the future.
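Because dataloader_idx is currently always 0, a hook that needs to treat batches from different sources differently cannot branch on it. A minimal sketch, assuming dict-shaped batches where one source can be recognized by a hypothetical "labels" key, of dispatching on the batch's contents instead (written as a plain function for brevity; the key names and the scaling are illustrative only):

```python
from typing import Any, Dict


def on_after_batch_transfer(batch: Dict[str, Any], dataloader_idx: int) -> Dict[str, Any]:
    # dataloader_idx is always 0 for now, so it cannot tell sources apart; ignore it
    if "labels" in batch:
        # Hypothetical labelled source: scale the features, leave the labels untouched
        batch["x"] = [v * 0.5 for v in batch["x"]]
    # Batches without "labels" pass through unchanged
    return batch
```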

Note

This hook only runs during single-GPU training and DDP (not data-parallel). Data-parallel support will come in the near future.

Parameters
batch : Any

A batch of data that needs to be altered or augmented.

dataloader_idx : int

Index of the DataLoader that produced the batch (default: 0).

Return type

Any

Returns

The (possibly modified) batch of data.

Example:

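```python
from scvi.dataloaders import DataSplitter

class MyDataSplitter(DataSplitter):
    def on_after_batch_transfer(self, batch, dataloader_idx):
        batch['x'] = gpu_transforms(batch['x'])  # gpu_transforms: user-defined callable
        return batch
```
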
Raises

MisconfigurationException – If data-parallel training is used, i.e. Trainer(accelerator='dp').