Enable disjoint feature networks
This is done by pre-training a feature net that predicts the features of HR images from LR images. Then use the original feature network and this new one in tandem to work only on LR/Gen images.
This commit is contained in:
parent
6e086d0c20
commit
eb11a08d1c
|
@ -15,10 +15,9 @@ class LQGTDataset(data.Dataset):
|
|||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
|
||||
If only GT images are provided, generate LQ images on-the-fly.
|
||||
"""
|
||||
|
||||
def get_lq_path(self, i):
|
||||
which_lq = random.randint(0, len(self.paths_LQ)-1)
|
||||
return self.paths_LQ[which_lq][i]
|
||||
return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
|
||||
|
||||
def __init__(self, opt):
|
||||
super(LQGTDataset, self).__init__()
|
||||
|
@ -53,11 +52,6 @@ class LQGTDataset(data.Dataset):
|
|||
print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,))
|
||||
|
||||
assert self.paths_GT, 'Error: GT path is empty.'
|
||||
if self.paths_LQ and self.paths_GT:
|
||||
assert len(self.paths_LQ[0]) == len(
|
||||
self.paths_GT
|
||||
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
|
||||
len(self.paths_LQ[0]), len(self.paths_GT))
|
||||
self.random_scale_list = [1]
|
||||
|
||||
def _init_lmdb(self):
|
||||
|
@ -85,7 +79,7 @@ class LQGTDataset(data.Dataset):
|
|||
GT_size = self.opt['target_size']
|
||||
|
||||
# get GT image
|
||||
GT_path = self.paths_GT[index]
|
||||
GT_path = self.paths_GT[index % len(self.paths_GT)]
|
||||
resolution = [int(s) for s in self.sizes_GT[index].split('_')
|
||||
] if self.data_type == 'lmdb' else None
|
||||
img_GT = util.read_img(self.GT_env, GT_path, resolution)
|
||||
|
|
|
@ -42,6 +42,7 @@ class SRGANModel(BaseModel):
|
|||
else:
|
||||
self.netC = None
|
||||
self.mega_batch_factor = 1
|
||||
self.disjoint_data = False
|
||||
|
||||
# define losses, optimizer and scheduler
|
||||
if self.is_train:
|
||||
|
@ -101,16 +102,28 @@ class SRGANModel(BaseModel):
|
|||
self.cri_fea = None
|
||||
if self.cri_fea: # load VGG perceptual loss
|
||||
self.netF = networks.define_F(opt, use_bn=False).to(self.device)
|
||||
self.lr_netF = None
|
||||
if 'lr_fea_path' in train_opt.keys():
|
||||
self.lr_netF = networks.define_F(opt, use_bn=False, load_path=train_opt['lr_fea_path']).to(self.device)
|
||||
self.disjoint_data = True
|
||||
|
||||
if opt['dist']:
|
||||
pass # do not need to use DistributedDataParallel for netF
|
||||
else:
|
||||
self.netF = DataParallel(self.netF)
|
||||
if self.lr_netF:
|
||||
self.lr_netF = DataParallel(self.lr_netF)
|
||||
|
||||
# You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses.
|
||||
self.fixed_disc_nets = []
|
||||
if 'fixed_discriminators' in opt.keys():
|
||||
for opt_fdisc in opt['fixed_discriminators'].keys():
|
||||
self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device))
|
||||
netFD = networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device)
|
||||
if opt['dist']:
|
||||
pass # do not need to use DistributedDataParallel for netF
|
||||
else:
|
||||
netFD = DataParallel(netFD)
|
||||
self.fixed_disc_nets.append(netFD)
|
||||
|
||||
# GD gan loss
|
||||
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
|
||||
|
@ -330,6 +343,9 @@ class SRGANModel(BaseModel):
|
|||
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
|
||||
l_g_total += l_g_fdpl * self.fdpl_weight
|
||||
if self.cri_fea and not using_gan_img: # feature loss
|
||||
if self.lr_netF is not None:
|
||||
real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
|
||||
else:
|
||||
real_fea = self.netF(pix).detach()
|
||||
fake_fea = self.netF(fea_GenOut)
|
||||
fea_w = self.l_fea_sched.get_weight_for_step(step)
|
||||
|
@ -346,7 +362,7 @@ class SRGANModel(BaseModel):
|
|||
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
|
||||
# it should target this
|
||||
|
||||
l_g_fix_disc = 0
|
||||
l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze()
|
||||
for fixed_disc in self.fixed_disc_nets:
|
||||
weight = fixed_disc.fdisc_weight
|
||||
real_fea = fixed_disc(pix).detach()
|
||||
|
@ -439,6 +455,7 @@ class SRGANModel(BaseModel):
|
|||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
if 'pixgan' in self.opt['train']['gan_type']:
|
||||
if not self.disjoint_data:
|
||||
# randomly determine portions of the image to swap to keep the discriminator honest.
|
||||
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
|
||||
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
|
||||
|
|
|
@ -26,8 +26,10 @@ class VGGFeatureExtractor(nn.Module):
|
|||
for k, v in self.features.named_parameters():
|
||||
v.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
# Assume input range is [0, 1]
|
||||
def forward(self, x, interpolate_factor=1):
|
||||
if interpolate_factor > 1:
|
||||
x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
|
||||
|
||||
if self.use_input_norm:
|
||||
x = (x - self.mean) / self.std
|
||||
output = self.features(x)
|
||||
|
|
|
@ -168,7 +168,7 @@ def define_fixed_D(opt):
|
|||
|
||||
|
||||
# Define network used for perceptual loss
|
||||
def define_F(opt, use_bn=False, for_training=False):
|
||||
def define_F(opt, use_bn=False, for_training=False, load_path=None):
|
||||
gpu_ids = opt['gpu_ids']
|
||||
device = torch.device('cuda' if gpu_ids else 'cpu')
|
||||
if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
|
||||
|
@ -186,5 +186,21 @@ def define_F(opt, use_bn=False, for_training=False):
|
|||
elif opt['train']['which_model_F'] == 'wide_resnet':
|
||||
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
|
||||
|
||||
netF.eval() # No need to train
|
||||
if load_path:
|
||||
# Load the model parameters:
|
||||
load_net = torch.load(load_path)
|
||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||
for k, v in load_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
netF.load_state_dict(load_net_clean)
|
||||
|
||||
# Put into eval mode, freeze the parameters and set the 'weight' field.
|
||||
netF.eval()
|
||||
for k, v in netF.named_parameters():
|
||||
v.requires_grad = False
|
||||
netF.fdisc_weight = opt['weight']
|
||||
|
||||
return netF
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user