DataLoader
Datasets
class torch_template.templates.dataloader.dataset.ImageDataset(datadir, crop=None, aug=True, norm=False)
ImageDataset for training.
Parameters:
- datadir (str) – dataset root path. Input and label images are read from its 'input' and 'gt' subdirectories by default.
- crop (None, int or tuple) – crop size.
- aug (bool) – data augmentation (×8).
- norm (bool) – normalization.
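The `datadir` layout described above can be illustrated with a small helper that pairs each file under `input` with the same-named file under `gt`. This is a sketch under the assumption that pairs are matched by file name; `pair_files` is a hypothetical helper, not part of torch_template, and the real dataset may match files differently.

```python
from pathlib import Path


def pair_files(datadir, input_dir="input", label_dir="gt"):
    """Return sorted (input_path, label_path) pairs matched by file name.

    Hypothetical helper: assumes every input image has a label image with
    the same file name in the label directory.
    """
    root = Path(datadir)
    inputs = {p.name: p for p in (root / input_dir).iterdir() if p.is_file()}
    labels = {p.name: p for p in (root / label_dir).iterdir() if p.is_file()}
    # Fail loudly instead of silently dropping unpaired inputs.
    missing = sorted(set(inputs) - set(labels))
    if missing:
        raise FileNotFoundError(f"no label found for: {missing}")
    return [(inputs[name], labels[name]) for name in sorted(inputs)]
```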
Example
    train_dataset = ImageDataset('train', crop=256)
    for i, data in enumerate(train_dataset):
        input, label, file_name = data
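What `crop` and `aug` (×8) most likely do can be sketched with NumPy. The sketch assumes that ×8 refers to the eight dihedral transforms (four quarter-turn rotations, each with and without a horizontal flip) and that input and label are cropped at the same random location. `dihedral` and `random_crop_aug` are hypothetical names, not torch_template functions, and images are assumed to be H × W (× C) arrays.

```python
import random

import numpy as np


def dihedral(img, k):
    """Apply the k-th of the 8 dihedral transforms (k in 0..7) to an H x W (x C) array.

    k % 4 quarter turns in the (H, W) plane, then a horizontal flip if k >= 4.
    """
    out = np.rot90(img, k % 4, axes=(0, 1))
    if k >= 4:
        out = out[:, ::-1]
    return out


def random_crop_aug(inp, label, crop=None, aug=True, rng=random):
    """Hypothetical sketch: crop input and label at the same random location,
    then apply one randomly chosen dihedral transform to both."""
    if crop is not None:
        ch, cw = (crop, crop) if isinstance(crop, int) else crop
        h, w = inp.shape[:2]
        top = rng.randint(0, h - ch)    # randint is inclusive on both ends
        left = rng.randint(0, w - cw)
        inp = inp[top:top + ch, left:left + cw]
        label = label[top:top + ch, left:left + cw]
    if aug:
        k = rng.randrange(8)  # one transform per sample, shared by input and label
        inp, label = dihedral(inp, k), dihedral(label, k)
    # Copy away negative strides from flips/rotations (torch.from_numpy rejects them).
    return np.ascontiguousarray(inp), np.ascontiguousarray(label)
```

Drawing a single `k` per sample keeps the input and its label geometrically aligned, which is the property a paired restoration dataset needs.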
class torch_template.templates.dataloader.dataset.ImageTestDataset(datadir, norm=False)
ImageDataset for testing.
Parameters:
- datadir (str) – dataset path.
- norm (bool) – normalization.
Example
    test_dataset = ImageTestDataset('test')
    for i, data in enumerate(test_dataset):
        input, file_name = data
TTA
TTA (test-time augmentation) plugin for the test data-loader loop, providing overlapping-patch inference and data augmentation (×8).
Author: zks@tju.edu.cn
Refactor: xuhaoyu@tju.edu.cn
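The ×8 data augmentation at test time is commonly implemented as a self-ensemble: run the model on each of the eight dihedral transforms of the input, undo the transform on each output, and average. A minimal NumPy sketch, assuming an image-to-image `model` whose output has the same H × W layout as its input; `tta_x8` and its helpers are hypothetical and may differ from what this module does.

```python
import numpy as np


def dihedral(img, k):
    # k % 4 quarter turns in the (H, W) plane, then a horizontal flip if k >= 4.
    out = np.rot90(img, k % 4, axes=(0, 1))
    return out[:, ::-1] if k >= 4 else out


def inverse_dihedral(img, k):
    # Undo in reverse order: remove the flip first, then rotate back.
    out = img[:, ::-1] if k >= 4 else img
    return np.rot90(out, -(k % 4), axes=(0, 1))


def tta_x8(model, img):
    """Average model outputs over the 8 dihedral transforms of img."""
    outputs = [inverse_dihedral(model(dihedral(img, k)), k) for k in range(8)]
    return np.mean(outputs, axis=0)
```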
class torch_template.dataloader.tta.OverlapTTA(img, nw, nh, patch_w=256, patch_h=256, norm_patch=False, flip_aug=False, device='cuda:0')
Overlap TTA: runs inference on an nw × nh grid of overlapping patches of img and combines the per-patch outputs.
Parameters:
- img – input image to be split into patches.
- nw (int) – number of patches in the width direction.
- nh (int) – number of patches in the height direction.
- patch_w (int) – width of a patch.
- patch_h (int) – height of a patch.
- norm_patch (bool) – whether to normalize each patch.
- flip_aug (bool) – flip augmentation; not used yet.
- device (str) – device string, default 'cuda:0'.
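The parameters `nw`, `nh`, `patch_w` and `patch_h` define a grid of patches. One common placement, assumed here and not confirmed by the source, spreads the start offsets evenly so that the first patch starts at 0 and the last one ends exactly at the image border; neighbouring patches then overlap whenever `n * patch` exceeds the image size. `patch_starts` and `patch_boxes` are hypothetical helpers.

```python
def patch_starts(size, n, patch):
    """Start offsets of n patches of length `patch` spread evenly over `size`.

    The first patch starts at 0 and the last one ends exactly at `size`.
    """
    if patch > size:
        raise ValueError("patch is larger than the image")
    if n * patch < size:
        raise ValueError("patches do not cover the image")
    if n == 1:
        return [0]  # only possible here when patch == size
    step = (size - patch) / (n - 1)
    return [round(i * step) for i in range(n)]


def patch_boxes(h, w, nw, nh, patch_w=256, patch_h=256):
    """(top, left, bottom, right) boxes for an nh x nw grid, row by row."""
    return [
        (top, left, top + patch_h, left + patch_w)
        for top in patch_starts(h, nh, patch_h)
        for left in patch_starts(w, nw, patch_w)
    ]
```

For example, `patch_starts(1000, 4, 256)` gives `[0, 248, 496, 744]`, so consecutive 256-pixel patches overlap by 8 pixels.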
Usage example
>>> import torch
>>> from torch_template import OverlapTTA
>>> for i, data in enumerate(dataset):
...     img, file_name = data  # e.g. an item from ImageTestDataset
...     tta = OverlapTTA(img, 10, 10, 256, 256, norm_patch=False, flip_aug=False, device=opt.device)
...     for j, x in enumerate(tta):  # get the input for each patch
...         generated = model(x)
...         torch.cuda.empty_cache()
...         tta.collect(generated[0], j)  # collect the inference result
...     output = tta.combine()
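The collect/combine pattern in the usage example can be sketched as overlap averaging: accumulate each patch output into a full-size buffer, count how many patches cover every pixel, and divide. This is an assumption about how overlaps are merged (weighted blending is another option); `OverlapAverager` is a hypothetical class, single-channel (H × W) for brevity, and uses the same (top, left, bottom, right) box convention for patch index `j`.

```python
import numpy as np


class OverlapAverager:
    """Hypothetical collect/combine helper for single-channel (H x W) outputs."""

    def __init__(self, h, w, boxes):
        self.boxes = boxes  # one (top, left, bottom, right) box per patch index j
        self.acc = np.zeros((h, w), dtype=np.float64)    # sum of patch outputs
        self.count = np.zeros((h, w), dtype=np.float64)  # patches covering each pixel

    def collect(self, patch_out, j):
        top, left, bottom, right = self.boxes[j]
        self.acc[top:bottom, left:right] += patch_out
        self.count[top:bottom, left:right] += 1

    def combine(self):
        if (self.count == 0).any():
            raise ValueError("some pixels are not covered by any patch")
        # Overlapping pixels become the mean of all patches that cover them.
        return self.acc / self.count
```

Dividing by the per-pixel count turns each overlapping region into the mean of the patches covering it, instead of letting the last collected patch overwrite the others.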