Crossgan
This commit is contained in:
parent
fd7b6ca0a9
commit
1d5f4f6102
|
@ -159,8 +159,6 @@ class SRGANModel(BaseModel):
|
|||
# GD gan loss
|
||||
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
|
||||
self.l_gan_w = train_opt['gan_weight']
|
||||
if train_opt['gan_type'] == 'pixgan':
|
||||
self.do_pixgan_swap = True if 'do_pixgan_swap' not in train_opt.keys() else train_opt['do_pixgan_swap']
|
||||
# D_update_ratio and D_init_iters
|
||||
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
|
||||
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
|
||||
|
@ -367,7 +365,8 @@ class SRGANModel(BaseModel):
|
|||
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
|
||||
if self.spsr_enabled and self.branch_pretrain and step < self.branch_init_iters:
|
||||
for k, v in self.netG.named_parameters():
|
||||
v.requires_grad = '_branch_pretrain' in k
|
||||
if v.dtype != torch.int64 and v.dtype != torch.bool:
|
||||
v.requires_grad = '_branch_pretrain' in k
|
||||
else:
|
||||
for p in self.netG.parameters():
|
||||
if p.dtype != torch.int64 and p.dtype != torch.bool:
|
||||
|
@ -472,8 +471,11 @@ class SRGANModel(BaseModel):
|
|||
|
||||
|
||||
if self.l_gan_w > 0:
|
||||
if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
|
||||
pred_g_fake = self.netD(fake_GenOut)
|
||||
if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
||||
if self.opt['train']['gan_type'] == 'crossgan':
|
||||
pred_g_fake = self.netD(fake_GenOut, var_L)
|
||||
else:
|
||||
pred_g_fake = self.netD(fake_GenOut)
|
||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_d_real = self.netD(var_ref).detach()
|
||||
|
@ -485,9 +487,12 @@ class SRGANModel(BaseModel):
|
|||
l_g_total += l_g_gan
|
||||
|
||||
if self.spsr_enabled and self.cri_grad_gan:
|
||||
pred_g_fake_grad = self.netD_grad(fake_H_grad)
|
||||
if self.opt['train']['gan_type'] == 'crossgan':
|
||||
pred_g_fake_grad = self.netD(fake_H_grad, var_L)
|
||||
else:
|
||||
pred_g_fake_grad = self.netD(fake_H_grad)
|
||||
pred_g_fake_grad_branch = self.netD_grad(fake_H_branch)
|
||||
if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
|
||||
if self.opt['train']['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
||||
l_g_gan_grad = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad, True)
|
||||
l_g_gan_grad_branch = self.l_gan_grad_w * self.cri_grad_gan(pred_g_fake_grad_branch, True)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
|
@ -527,7 +532,7 @@ class SRGANModel(BaseModel):
|
|||
noise.to(self.device)
|
||||
real_disc_images = []
|
||||
fake_disc_images = []
|
||||
for fake_GenOut, var_LGAN, var_H, var_ref, pix in zip(self.fake_GenOut, self.gan_img, self.var_H, self.var_ref, self.pix):
|
||||
for fake_GenOut, var_LGAN, var_L, var_H, var_ref, pix in zip(self.fake_GenOut, self.gan_img, self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
if random.random() > self.gan_lq_img_use_prob:
|
||||
fake_H = fake_GenOut.clone().detach().requires_grad_(False)
|
||||
else:
|
||||
|
@ -558,8 +563,27 @@ class SRGANModel(BaseModel):
|
|||
_, fea_fake = self.netD(fake_H, output_feature_vector=True)
|
||||
actual_fea = self.netF(fake_H)
|
||||
l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
if self.opt['train']['gan_type'] == 'crossgan':
|
||||
# need to forward and backward separately, since batch norm statistics differ
|
||||
# real
|
||||
pred_d_real = self.netD(var_ref, var_L)
|
||||
l_d_real = self.cri_gan(pred_d_real, True)
|
||||
l_d_real_log = l_d_real
|
||||
# fake
|
||||
pred_d_fake = self.netD(fake_H, var_L)
|
||||
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||
l_d_fake_log = l_d_fake
|
||||
# mismatched
|
||||
mismatched_L = torch.roll(var_L, shifts=1, dims=0)
|
||||
pred_d_real_mismatched = self.netD(var_ref, mismatched_L)
|
||||
pred_d_fake_mismatched = self.netD(fake_H, mismatched_L)
|
||||
l_d_mismatched = (self.cri_gan(pred_d_real_mismatched, False) + self.cri_gan(pred_d_fake_mismatched, False)) / 2
|
||||
|
||||
l_d_total = (l_d_real + l_d_fake + l_d_mismatched) / 3
|
||||
l_d_total = l_d_total / self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
|
||||
l_d_total_scaled.backward()
|
||||
elif self.opt['train']['gan_type'] == 'gan':
|
||||
# real
|
||||
pred_d_real = self.netD(var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real, True) / self.mega_batch_factor
|
||||
|
@ -572,13 +596,13 @@ class SRGANModel(BaseModel):
|
|||
l_d_total = (l_d_real + l_d_fake) / 2
|
||||
with amp.scale_loss(l_d_total, self.optimizer_D, loss_id=1) as l_d_total_scaled:
|
||||
l_d_total_scaled.backward()
|
||||
if 'pixgan' in self.opt['train']['gan_type']:
|
||||
elif 'pixgan' in self.opt['train']['gan_type']:
|
||||
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
|
||||
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
|
||||
b, _, w, h = var_ref.shape
|
||||
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
|
||||
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
|
||||
if self.do_pixgan_swap and not self.disjoint_data:
|
||||
if not self.disjoint_data:
|
||||
# randomly determine portions of the image to swap to keep the discriminator honest.
|
||||
SWAP_MAX_DIM = w // 4
|
||||
SWAP_MIN_DIM = 16
|
||||
|
@ -631,7 +655,6 @@ class SRGANModel(BaseModel):
|
|||
pdf = pred_d_fake.detach() + torch.abs(torch.min(pred_d_fake))
|
||||
pdf = pdf / torch.max(pdf)
|
||||
fake_disc_images.append(pdf.view(disc_output_shape))
|
||||
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_d_fake = self.netD(fake_H)
|
||||
pred_d_real = self.netD(var_ref)
|
||||
|
@ -667,6 +690,8 @@ class SRGANModel(BaseModel):
|
|||
if self.opt['train']['gan_type'] == 'gan':
|
||||
l_d_real_grad = self.cri_gan(pred_d_real_grad, True)
|
||||
l_d_fake_grad = (self.cri_gan(pred_d_fake_grad, False) + self.cri_gan(pred_d_fake_grad_branch, False)) / 2
|
||||
elif self.opt['train']['gan_type'] == 'crossgan':
|
||||
assert False
|
||||
elif self.opt['train']['gan_type'] == 'pixgan':
|
||||
real = torch.ones_like(pred_d_real_grad)
|
||||
fake = torch.zeros_like(pred_d_fake_grad)
|
||||
|
@ -732,6 +757,8 @@ class SRGANModel(BaseModel):
|
|||
self.add_log_entry('l_d_fea_real', l_d_fea_real.detach().item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_fake_total', l_d_fake.detach().item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_real_total', l_d_real.detach().item() * self.mega_batch_factor)
|
||||
if self.opt['train']['gan_type'] == 'crossgan':
|
||||
self.add_log_entry('l_d_mismatched', l_d_mismatched.detach().item())
|
||||
if self.spsr_enabled:
|
||||
if self.cri_pix_grad:
|
||||
self.add_log_entry('l_g_pix_grad_branch', l_g_pix_grad.detach().item())
|
||||
|
|
|
@ -140,13 +140,14 @@ class Discriminator_VGG_128_GN(nn.Module):
|
|||
|
||||
class CrossCompareBlock(nn.Module):
|
||||
def __init__(self, nf_in, nf_out):
|
||||
super(CrossCompareBlock, self).__init__()
|
||||
self.conv_hr_merge = ConvGnLelu(nf_in * 2, nf_in, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.proc_hr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.proc_lr = ConvGnLelu(nf_in, nf_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.reduce_hr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
|
||||
self.reduce_lr = ConvGnLelu(nf_out, nf_out, kernel_size=3, stride=2, bias=False, activation=True, norm=True)
|
||||
|
||||
def forward(self, lr, hr):
|
||||
def forward(self, hr, lr):
|
||||
hr = self.conv_hr_merge(torch.cat([hr, lr], dim=1))
|
||||
hr = self.proc_hr(hr)
|
||||
hr = self.reduce_hr(hr)
|
||||
|
@ -154,7 +155,7 @@ class CrossCompareBlock(nn.Module):
|
|||
lr = self.proc_lr(lr)
|
||||
lr = self.reduce_lr(lr)
|
||||
|
||||
return lr, hr
|
||||
return hr, lr
|
||||
|
||||
|
||||
class CrossCompareDiscriminator(nn.Module):
|
||||
|
@ -177,17 +178,23 @@ class CrossCompareDiscriminator(nn.Module):
|
|||
self.fproc_conv = ConvGnLelu(nf * 8, nf, norm=True, bias=True, activation=True)
|
||||
self.out_conv = ConvGnLelu(nf, 1, norm=False, bias=False, activation=False)
|
||||
|
||||
def forward(self, lr, hr):
|
||||
self.scale = scale * 16
|
||||
|
||||
def forward(self, hr, lr):
|
||||
hr = self.init_conv_hr(hr)
|
||||
hr = self.second_conv(hr)
|
||||
lr = self.init_conv_lr(lr)
|
||||
|
||||
lr, hr = self.cross1(lr, hr)
|
||||
lr, hr = self.cross2(lr, hr)
|
||||
lr, hr = self.cross3(lr, hr)
|
||||
_, hr = self.cross4(lr, hr)
|
||||
hr, lr = self.cross1(hr, lr)
|
||||
hr, lr = self.cross2(hr, lr)
|
||||
hr, lr = self.cross3(hr, lr)
|
||||
hr, _ = self.cross4(hr, lr)
|
||||
|
||||
return self.out_conv(self.fproc_conv(hr))
|
||||
return self.out_conv(self.fproc_conv(hr)).view(-1, 1)
|
||||
|
||||
# Returns tuple of (number_output_channels, scale_of_output_reduction (1/n))
|
||||
def pixgan_parameters(self):
|
||||
return 3, self.scale
|
||||
|
||||
|
||||
class Discriminator_VGG_PixLoss(nn.Module):
|
||||
|
|
|
@ -26,7 +26,7 @@ class GANLoss(nn.Module):
|
|||
self.real_label_val = real_label_val
|
||||
self.fake_label_val = fake_label_val
|
||||
|
||||
if self.gan_type == 'gan' or self.gan_type == 'ragan' or self.gan_type == 'pixgan' or self.gan_type == "pixgan_fea":
|
||||
if self.gan_type in ['gan', 'ragan', 'pixgan', 'pixgan_fea', 'crossgan']:
|
||||
self.loss = nn.BCEWithLogitsLoss()
|
||||
elif self.gan_type == 'lsgan':
|
||||
self.loss = nn.MSELoss()
|
||||
|
@ -40,7 +40,7 @@ class GANLoss(nn.Module):
|
|||
return torch.empty_like(input).fill_(self.fake_label_val)
|
||||
|
||||
def forward(self, input, target_is_real):
|
||||
if 'pixgan' in self.gan_type and not isinstance(target_is_real, bool):
|
||||
if self.gan_type in ['pixgan', 'pixgan_fea', 'crossgan'] and not isinstance(target_is_real, bool):
|
||||
target_label = target_is_real
|
||||
else:
|
||||
target_label = self.get_target_label(input, target_is_real)
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_rrdb_noskip.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user