DataLoader

Datasets

class torch_template.templates.dataloader.dataset.ImageDataset(datadir, crop=None, aug=True, norm=False)

Image dataset for training; each item is an (input, label, file_name) tuple.

Parameters:
  • datadir (str) – dataset root path; the default input and label subdirectories are ‘input’ and ‘gt’
  • crop (None, int or tuple) – crop size
  • aug (bool) – data augmentation (×8)
  • norm (bool) – whether to apply normalization

Example

>>> from torch_template.templates.dataloader.dataset import ImageDataset
>>> train_dataset = ImageDataset('train', crop=256)
>>> for i, data in enumerate(train_dataset):
...     input, label, file_name = data

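The ×8 factor of aug matches the 8 symmetries of a square: 4 rotations by multiples of 90°, each with or without a horizontal flip. The following is a minimal NumPy sketch, assuming that this is what aug does and that crop takes a random window at the same position in input and label. augment_pair is a hypothetical helper, not part of torch_template, and it works on H × W (× C) arrays rather than tensors.

```python
import random

import numpy as np


def augment_pair(inp, label, crop=None, aug=True):
    """Hypothetical sketch: paired random crop plus one of 8 dihedral transforms.

    inp and label are H x W (x C) arrays with the same spatial size; the same
    crop window and the same transform are applied to both.
    """
    if crop is not None:
        ch, cw = (crop, crop) if isinstance(crop, int) else crop
        h, w = inp.shape[:2]
        top = random.randint(0, h - ch)    # randint is inclusive on both ends
        left = random.randint(0, w - cw)
        inp = inp[top:top + ch, left:left + cw]
        label = label[top:top + ch, left:left + cw]
    if aug:
        k = random.randint(0, 3)           # number of 90-degree rotations
        flip = random.random() < 0.5       # optional horizontal flip
        inp, label = np.rot90(inp, k), np.rot90(label, k)
        if flip:
            inp, label = np.fliplr(inp), np.fliplr(label)
    # rot90/fliplr return views; copy them into contiguous arrays
    return np.ascontiguousarray(inp), np.ascontiguousarray(label)
```

Because 4 rotation counts times 2 flip choices give exactly 8 distinct transforms, one random draw per sample realizes the ×8 augmentation.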
class torch_template.templates.dataloader.dataset.ImageTestDataset(datadir, norm=False)

Image dataset for testing; each item is an (input, file_name) tuple (no labels).

Parameters:
  • datadir (str) – dataset path
  • norm (bool) – whether to apply normalization

Example

>>> from torch_template.templates.dataloader.dataset import ImageTestDataset
>>> test_dataset = ImageTestDataset('test')
>>> for i, data in enumerate(test_dataset):
...     input, file_name = data
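For ImageDataset, each input under datadir/‘input’ must be matched to its label under datadir/‘gt’. A minimal sketch, assuming the match is by identical file name (which would also explain why file_name is returned with each item); list_pairs is a hypothetical helper, not the library's implementation.

```python
import os


def list_pairs(datadir, input_dir="input", label_dir="gt"):
    """Hypothetical sketch: pair datadir/input/<name> with datadir/gt/<name>.

    Returns a sorted list of (input_path, label_path, file_name) tuples and
    raises FileNotFoundError if an input has no label with the same name.
    """
    pairs = []
    for name in sorted(os.listdir(os.path.join(datadir, input_dir))):
        label_path = os.path.join(datadir, label_dir, name)
        if not os.path.isfile(label_path):
            raise FileNotFoundError(f"no label for {name}: {label_path}")
        pairs.append((os.path.join(datadir, input_dir, name), label_path, name))
    return pairs
```

ImageTestDataset has no ‘gt’ directory to pair against, which is consistent with its items carrying only (input, file_name).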

TTA

Test-time augmentation (TTA) plugin for use in the test data_loader loop, providing overlapping-patch inference and data augmentation (8×).

Author: zks@tju.edu.cn

Refactor: xuhaoyu@tju.edu.cn
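A common way to implement 8× test-time augmentation is a self-ensemble: transform the image with each of the 8 rotation/flip combinations, run the model, map each output back with the inverse transform, and average. A minimal NumPy sketch, assuming an image-to-image model that preserves spatial size; tta_8x is a hypothetical helper and not the OverlapTTA API (whose flip_aug option is documented as not used yet).

```python
import numpy as np


def tta_8x(model, img):
    """Hypothetical sketch of 8x test-time augmentation (self-ensemble).

    img is an H x W (x C) array; model maps an array to an array of the same
    spatial size, e.g. a restoration network.
    """
    outputs = []
    for k in range(4):                    # 0, 90, 180, 270 degree rotations
        for flip in (False, True):        # without / with a horizontal flip
            x = np.rot90(img, k)
            if flip:
                x = np.fliplr(x)
            y = model(np.ascontiguousarray(x))
            # invert in reverse order: undo the flip, then rotate back
            if flip:
                y = np.fliplr(y)
            outputs.append(np.rot90(y, -k))
    return np.mean(outputs, axis=0)
```

For a model that is exactly equivariant to rotations and flips, all 8 outputs coincide; for real networks they differ slightly, and averaging them usually reduces variance.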

class torch_template.dataloader.tta.OverlapTTA(img, nw, nh, patch_w=256, patch_h=256, norm_patch=False, flip_aug=False, device='cuda:0')

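Overlap TTA: splits img into an nw × nh grid of patches of size patch_w × patch_h, yields the patches one by one for inference, and merges the collected outputs back into a full image with combine().

Parameters:
  • img – input image to be split into patches
  • nw (int) – number of patches in the width direction
  • nh (int) – number of patches in the height direction
  • patch_w (int) – width of a patch
  • patch_h (int) – height of a patch
  • norm_patch (bool) – whether to normalize each patch
  • flip_aug (bool) – not used yet
  • device (str) – device string, default ‘cuda:0’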
Usage Example
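>>> import torch
>>> from torch_template import OverlapTTA
>>> # `dataset`, `model` and `opt` are assumed to be defined already
>>> for i, data in enumerate(dataset):
...     img = data[0]  # input image, e.g. the first element of ImageTestDataset's (input, file_name)
...     tta = OverlapTTA(img, 10, 10, 256, 256, norm_patch=False, flip_aug=False, device=opt.device)
...     for j, x in enumerate(tta):  # get each patch as model input
...         generated = model(x)
...         torch.cuda.empty_cache()
...         tta.collect(generated[0], j)  # collect the inference result for patch j
...     output = tta.combine()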
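The split/collect/combine cycle can be sketched as follows, assuming evenly spaced patch positions (first patch at the image edge, last patch flush with the opposite edge) and plain averaging where patches overlap. patch_starts and overlap_tta are hypothetical helpers operating on H × W (× C) NumPy arrays; OverlapTTA itself works on tensors on a device, and its exact placement and blending rules may differ.

```python
import numpy as np


def patch_starts(length, n, patch):
    """Hypothetical sketch: start offsets of n patches of size `patch` on one axis.

    The first patch starts at 0 and the last ends exactly at `length`;
    n * patch >= length guarantees every pixel is covered, so consecutive
    patches may overlap.
    """
    if patch > length or n * patch < length:
        raise ValueError("n patches of this size cannot tile the axis")
    if n == 1:
        return [0]
    # integer arithmetic keeps the last start exactly at length - patch
    return [i * (length - patch) // (n - 1) for i in range(n)]


def overlap_tta(model, img, nw, nh, patch_w=256, patch_h=256):
    """Hypothetical sketch: run `model` on an nh x nw grid of patches and
    average the outputs wherever patches overlap."""
    h, w = img.shape[:2]
    out = np.zeros(img.shape, dtype=np.float64)
    # one count per pixel, broadcast over any trailing channel axis
    count = np.zeros((h, w) + (1,) * (img.ndim - 2))
    for top in patch_starts(h, nh, patch_h):
        for left in patch_starts(w, nw, patch_w):
            rows = slice(top, top + patch_h)
            cols = slice(left, left + patch_w)
            out[rows, cols] += model(img[rows, cols])  # model keeps patch shape
            count[rows, cols] += 1
    return out / count  # every pixel is covered at least once
```

Averaging overlaps smooths seams at patch borders; weighting each patch with a window that decays toward its edges is a common refinement of the same idea.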