From b2507be13c2403b654f42023dcc12c5a23d0cf69 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 8 Jul 2020 21:27:48 -0600 Subject: [PATCH] Fix up pixgan loss and pixdisc --- codes/models/SRGAN_model.py | 18 +++++++++--------- codes/models/archs/discriminator_vgg_arch.py | 3 ++- codes/models/loss.py | 2 +- codes/train.py | 2 +- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 51cc95e0..6de9664c 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -271,7 +271,7 @@ class SRGANModel(BaseModel): self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay) if self.l_gan_w > 0: - if self.opt['train']['gan_type'] == 'gan': + if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan': pred_g_fake = self.netD(fake_GenOut) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': @@ -344,8 +344,8 @@ class SRGANModel(BaseModel): PIXDISC_OUTPUT_REDUCTION = 8 PIXDISC_MAX_REDUCTION = 32 disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION) - real = torch.ones(disc_output_shape) - fake = torch.zeros(disc_output_shape) + real = torch.ones(disc_output_shape, device=var_ref[0].device) + fake = torch.zeros(disc_output_shape, device=var_ref[0].device) # randomly determine portions of the image to swap to keep the discriminator honest. if random.random() > .25: @@ -353,16 +353,16 @@ class SRGANModel(BaseModel): # Make the swap across fake_H and var_ref SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1 assert SWAP_MAX_DIM > 0 - swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION + swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION - t = fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h].clone() - fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] - var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = t + t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() + fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] + var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t # Swap the expectation matrix too. swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION - real[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 0.0 - fake[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 1.0 + real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0 + fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0 # We're also assuming that this is exactly how the flattened discriminator output is generated. real = real.view(-1, 1) diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index f56b380e..61002b59 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -161,7 +161,8 @@ class Discriminator_VGG_PixLoss(nn.Module): dec2 = torch.cat([dec2, fea2], dim=1) dec2 = self.up2_converge(dec2) dec2 = self.up2_proc(dec2) - loss2 = self.up2_reduce(dec2) + dec2 = self.up2_reduce(dec2) + loss2 = self.up2_pix(dec2) # Compress all of the loss values into the batch dimension. The actual loss attached to this output will # then know how to handle them. diff --git a/codes/models/loss.py b/codes/models/loss.py index 0c114073..342aad38 100644 --- a/codes/models/loss.py +++ b/codes/models/loss.py @@ -46,7 +46,7 @@ class GANLoss(nn.Module): return torch.empty_like(input).fill_(self.fake_label_val) def forward(self, input, target_is_real): - if self.gan_type == 'pixgan': + if self.gan_type == 'pixgan' and not isinstance(target_is_real, bool): target_label = target_is_real else: target_label = self.get_target_label(input, target_is_real) diff --git a/codes/train.py b/codes/train.py index 200bbff7..6ad0f763 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_rrdb_pixgan_normal_gan.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_rrdb.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)