# Source code for torch_template.dataloader.tta

# encoding=utf-8
"""
    TTA plugin used in test data_loader loop, containing overlap and data_aug (8×)

    Author: zks@tju.edu.cn

    Refactor: xuhaoyu@tju.edu.cn
"""

import pdb

import torch
from torch.autograd import Variable
from torchvision.transforms import transforms


class OverlapTTA(object):
    """Overlapping-patch test-time augmentation (TTA) helper.

    The input image is cut into an ``nh`` x ``nw`` grid of (possibly
    overlapping) patches.  Indexing or iterating the object yields the patches
    one by one; the per-patch network outputs are handed back through
    :meth:`collect` and finally averaged over the overlaps by :meth:`combine`.

    Patches are numbered row-major: ``index = row * nw + col``.

    Args:
        img: tensor whose shape unpacks as ``(N, C, H, W)``.
        nw(int): num of patches (in width direction), must be >= 1.
        nh(int): num of patches (in height direction), must be >= 1.
        patch_w(int): width of a patch, must not exceed ``W``.
        patch_h(int): height of a patch, must not exceed ``H``.
        norm_patch(bool): if norm each patch or not.  When set, only the first
            batch element (``img[0]``) is normalized and returned; the
            mean/std tuples have three entries, so this presumably expects a
            3-channel image.
        flip_aug(bool): not implemented yet; :meth:`__getitem__` and
            :meth:`combine` raise ``NotImplementedError`` when it is set.
        device(str): device string the patches are moved to, default 'cuda:0'.

    Raises:
        ValueError: if ``nw``/``nh`` is smaller than 1 or the patch is larger
            than the image.

    Usage Example::

        from torch_template import OverlapTTA
        for i, data in enumerate(dataset):
            tta = OverlapTTA(img, 10, 10, 256, 256, norm_patch=False,
                             flip_aug=False, device=opt.device)
            for j, x in enumerate(tta):       # fetch every patch as network input
                generated = model(x)
                torch.cuda.empty_cache()
                tta.collect(generated[0], j)  # gather the inference result
            output = tta.combine()
    """

    def __init__(self, img, nw, nh, patch_w=256, patch_h=256, norm_patch=False, flip_aug=False, device='cuda:0'):
        if nw < 1 or nh < 1:
            raise ValueError('nw and nh must be >= 1, got nw={}, nh={}'.format(nw, nh))

        self.img = img
        self.nw = nw
        self.nh = nh
        self.patch_w = patch_w
        self.patch_h = patch_h
        self.N, self.C, self.H, self.W = img.shape

        if patch_h > self.H or patch_w > self.W:
            raise ValueError(
                'patch ({}x{}) is larger than the image ({}x{})'.format(patch_h, patch_w, self.H, self.W))

        self.norm_patch = norm_patch
        self.flip_aug = flip_aug
        # The normalization pipeline is only built when it is actually requested.
        self.transforms = transforms.Compose(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) if norm_patch else None
        self.device = device

        # Patch positions along each axis, stored as [start, stop] pairs.
        self.slice_h = self._make_slices(self.H, patch_h, nh)
        self.slice_w = self._make_slices(self.W, patch_w, nw)

        # How many collected patches touched each pixel.
        self.overlap_times = torch.zeros((self.C, self.H, self.W)).cpu()
        # Running sum of the collected patch outputs.
        self.result = torch.zeros((self.C, self.H, self.W)).cpu()

    @staticmethod
    def _make_slices(length, patch, count):
        """Return ``count`` ``[start, stop]`` windows of size ``patch`` along one axis."""
        # A single window has no stride; this also avoids dividing by (count - 1) == 0.
        stride = (length - patch) // (count - 1) if count > 1 else 0
        # All windows but the last follow the fixed stride ...
        slices = [[k * stride, k * stride + patch] for k in range(count - 1)]
        # ... and the last one is pinned to the far edge of the axis.
        slices.append([length - patch, length])
        return slices

    def _locate(self, index):
        """Map a flat row-major patch index to its ``(slice_h, slice_w)`` entries."""
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            # IndexError is also what terminates ``for x in tta`` iteration.
            raise IndexError('patch index out of range')
        row, col = divmod(index, self.nw)
        return self.slice_h[row], self.slice_w[col]

    def collect(self, x, cur):
        """Accumulate the network output ``x`` for patch number ``cur``.

        Args:
            x: tensor added onto a ``(C, patch_h, patch_w)`` region of the
                result buffer; it is detached and moved to the CPU first.
            cur(int): flat patch index, as produced by enumerating this object.

        Raises:
            IndexError: if ``cur`` is outside ``[-len(self), len(self))``.
        """
        x = x.detach().cpu()
        (top, bottom), (left, right) = self._locate(cur)

        # Record the summed output and the number of contributions separately.
        self.result[:, top:bottom, left:right] += x
        self.overlap_times[:, top:bottom, left:right] += 1

    def combine(self):
        """Return the collected outputs averaged over the overlaps, shape ``(C, H, W)``.

        Pixels that no collected patch covered stay at 0 instead of becoming
        NaN (the divisor is clamped to 1).

        Raises:
            NotImplementedError: if ``flip_aug`` is set.
        """
        if self.flip_aug:
            raise NotImplementedError('flip_aug is not implemented yet')
        return self.result / self.overlap_times.clamp(min=1)

    def __getitem__(self, index):
        """Return TTA patch number ``index`` as network input on ``self.device``.

        :param index: flat row-major patch index (negative values count from the end)
        :return: detached patch tensor of shape ``(N, C, patch_h, patch_w)``, or
            ``(1, C, patch_h, patch_w)`` when ``norm_patch`` is set
        :raises IndexError: if ``index`` is out of range
        :raises NotImplementedError: if ``flip_aug`` is set
        """
        if self.flip_aug:
            raise NotImplementedError('flip_aug is not implemented yet')

        (top, bottom), (left, right) = self._locate(index)
        img = self.img[:, :, top:bottom, left:right]
        if self.norm_patch:
            img = self.transforms(img[0]).unsqueeze(dim=0)
        # detach() keeps the patch out of the autograd graph (the role the
        # deprecated Variable(..., requires_grad=False) wrapper used to play);
        # .to() works for CPU as well as CUDA device strings.
        return img.detach().to(self.device)

    def __len__(self):
        """Return the total number of patches (``nw * nh``)."""
        return self.nw * self.nh