import pdb
import numpy as np
import torch
import os
from torch import optim
import torch.nn.functional as F
from torch_template import model_zoo
from torch_template.loss.seg_loss import bce_loss, dice_loss, BCEFocalLoss
from torch_template.network.base_model import BaseModel
from torch_template.network.metrics import ssim, L1_loss
from torch_template.utils.torch_utils import ExponentialMovingAverage, print_network
# Networks selectable through ``opt.model``. Each entry is constructed once,
# when this module is imported, so every lookup returns the same instance.
models = dict(Nested=model_zoo['NestedUNet']())
def weights_init(m):
    """Re-initialise the parameters of a single module in place.

    Meant to be used as the callback of ``nn.Module.apply``, which calls it
    once for every sub-module. Modules are selected by class *name*:

    * name contains ``'Conv'``: weight drawn from N(0.0, 0.02)
    * name contains ``'BatchNorm2d'``: weight drawn from N(1.0, 0.02) and
      bias filled with 0

    Modules that match by name but carry no such parameter (for example a
    container class called ``ConvBlock``, a layer built with
    ``affine=False``) are skipped instead of raising ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # ``weight`` / ``bias`` may be missing (container modules) or ``None``
    # (layers created without learnable parameters).
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)
class Model(BaseModel):
    """Training wrapper around one network, ``self.cleaner``.

    The network is picked from the module-level ``models`` registry by
    ``opt.model`` and optimised with Adam on a weighted BCE + Dice loss.

    The adversarial helpers (``update_D``, ``discriminate``) rely on
    attributes this class never creates (``d_optimizer``, ``discriminitor``,
    ``netD``, ``fake_pool``). They only work if those attributes are
    attached from outside; otherwise they raise ``AttributeError``.
    """

    def __init__(self, opt):
        """Build the network, its optimizer and the running loss meters.

        Args:
            opt: options object. Attributes read here: ``model``, ``device``,
                ``lr``, ``load``, ``which_epoch``, ``checkpoint_dir``, ``tag``.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.cleaner = models[opt.model].cuda(device=opt.device)

        # Init weights
        self.cleaner.apply(weights_init)
        print_network(self.cleaner)

        self.g_optimizer = optim.Adam(self.cleaner.parameters(), lr=opt.lr)
        # Starting point for the linear decay in ``update_learning_rate``;
        # without it the first decay step fails on a missing attribute.
        self.old_lr = opt.lr

        # Load networks (generator only; no discriminator is built here).
        if opt.load:
            pretrained_path = opt.load
            self.load_network(self.cleaner, 'G', opt.which_epoch, pretrained_path)

        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)

    def update_G(self, img_var, y):
        """Run one optimisation step of ``self.cleaner``.

        Args:
            img_var: network input.
            y: training target.

        Returns:
            The raw network output (before the sigmoid) for ``img_var``.
        """
        opt = self.opt
        cleaned = self.cleaner(img_var)

        # ``torch.sigmoid`` replaces the deprecated ``F.sigmoid``.
        prediction = torch.sigmoid(cleaned)
        # NOTE(review): the target is squashed through a sigmoid as well. If
        # ``y`` is already a mask in [0, 1] this maps it to about
        # [0.5, 0.73] - confirm that this is intended.
        target = torch.sigmoid(y)

        # Losses. L1 is tracked for monitoring only and is not optimised.
        bce = bce_loss(prediction, target) * opt.weight_bce
        dice = dice_loss(prediction, target) * opt.weight_dice
        l1 = L1_loss(prediction, target)
        loss = bce + dice

        self.avg_meters.update({'bce': bce.item(), 'dice': dice.item(), 'l1': l1.item()})

        self.g_optimizer.zero_grad()
        loss.backward()
        self.g_optimizer.step()
        return cleaned

    def update_D(self, x, y):
        """Run one optimisation step of the discriminator.

        Requires ``self.d_optimizer`` and ``self.discriminitor``, neither of
        which is created by this class.

        Args:
            x: input passed through ``self.cleaner`` to produce the fake sample.
            y: real sample handed to the discriminator.

        Returns:
            The output of ``self.cleaner`` for ``x``.
        """
        self.d_optimizer.zero_grad()
        cleaned = self.cleaner(x)

        # D loss
        loss_dis = self.discriminitor.calc_dis_loss(input_fake=cleaned, input_real=y)
        # Log a plain float, as ``update_G`` does, so the meter does not keep
        # the autograd graph alive.
        self.avg_meters.update({'dis': loss_dis.item()})

        loss_dis = loss_dis * 1.  # weights
        loss_dis.backward()
        self.d_optimizer.step()
        return cleaned

    def discriminate(self, input_label, test_image, use_pool=False):
        """Score ``input_label`` concatenated with a detached ``test_image``.

        Requires ``self.netD`` (and ``self.fake_pool`` when ``use_pool`` is
        true), neither of which is created by this class.

        Args:
            input_label: first tensor of the pair.
            test_image: second tensor; detached before concatenation so no
                gradient flows back through it.
            use_pool: when true, pass the pair through
                ``self.fake_pool.query`` before scoring.

        Returns:
            The output of ``self.netD.forward``.
        """
        # Concatenated along dim 1 (presumably the channel axis - confirm).
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        return self.netD.forward(input_concat)

    def forward(self, x):
        """Return the raw output of ``self.cleaner`` for ``x``."""
        return self.cleaner(x)

    def inference(self, x, image=None):
        """Placeholder; not implemented, returns ``None``."""
        pass

    def save(self, which_epoch):
        """Save ``self.cleaner`` under the label ``'G'`` for ``which_epoch``."""
        self.save_network(self.cleaner, 'G', which_epoch)

    def update_learning_rate(self):
        """Lower the learning rate by one linear-decay step.

        The step size is ``opt.lr / opt.niter_decay``. The new rate is written
        to the generator optimizer and, if one has been attached, to the
        discriminator optimizer.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd

        # ``d_optimizer`` is optional: this class only creates ``g_optimizer``.
        optimizers = [self.g_optimizer]
        d_optimizer = getattr(self, 'd_optimizer', None)
        if d_optimizer is not None:
            optimizers.append(d_optimizer)

        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr