scvi.dataloaders.SemiSupervisedDataSplitter.on_before_batch_transfer

SemiSupervisedDataSplitter.on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply augmentations to a batch before it is transferred to the device.

Warning

dataloader_idx is currently always 0; it will be updated to pass the true index in the future.

Note

This hook only runs with single-GPU training and DDP (not data-parallel). Data-parallel support will come in the near future.

Parameters
batch : Any

A batch of data that needs to be altered or augmented.

dataloader_idx : int

Index of the DataLoader that produced the batch.

Return type

Any

Returns

The (possibly altered) batch of data.

Example:

def on_before_batch_transfer(self, batch, dataloader_idx):
    # `transforms` stands for any user-defined callable that augments batch['x']
    batch['x'] = transforms(batch['x'])
    return batch
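The self-contained sketch below shows the override pattern end to end without requiring scvi-tools or Lightning. `HookStandIn` is a hypothetical placeholder for the parent class; in a real project you would subclass scvi.dataloaders.SemiSupervisedDataSplitter instead. The batch layout (a dict whose "X" entry holds a list of rows) and the log1p transform are illustrative assumptions, not scvi-tools defaults.

```python
import math
from typing import Any, Callable


class HookStandIn:
    """Hypothetical stand-in for the parent data module; the default hook
    returns the batch unchanged."""

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        return batch


class Log1pSplitter(HookStandIn):
    """Hypothetical subclass that applies log1p to every value under one key."""

    def __init__(self, key: str = "X", transform: Callable[[float], float] = math.log1p):
        self.key = key  # assumed batch key; adjust to your registry keys
        self.transform = transform

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        # dataloader_idx is currently always 0, so it is ignored here.
        batch[self.key] = [[self.transform(v) for v in row] for row in batch[self.key]]
        # Return the modified batch: the returned object is what is used
        # as the batch from this point on.
        return batch


# Usage: other keys in the batch pass through untouched.
splitter = Log1pSplitter()
batch = {"X": [[0.0, math.e - 1]], "labels": [[1]]}
out = splitter.on_before_batch_transfer(batch, dataloader_idx=0)
print(out["X"], out["labels"])
```

The transform is injected through the constructor so the same subclass can be reused with a different augmentation without overriding the hook again.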
Raises

MisconfigurationException – If using data-parallel, i.e. Trainer(accelerator='dp').