Fix up pixgan loss and pixdisc

James Betker 2020-07-08 21:27:48 -06:00
parent 26a4a66d1c
commit b2507be13c
4 changed files with 13 additions and 12 deletions


@@ -271,7 +271,7 @@ class SRGANModel(BaseModel):
             self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
         if self.l_gan_w > 0:
-            if self.opt['train']['gan_type'] == 'gan':
+            if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
                 pred_g_fake = self.netD(fake_GenOut)
                 l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
             elif self.opt['train']['gan_type'] == 'ragan':
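
For reference, a minimal standalone sketch of the generator-side path this hunk enables (illustrative names only: criterion, net_d, fake_img are placeholders, assuming a BCE-with-logits criterion rather than the repo's configured one). Because the criterion fills a "real" target with the same shape as whatever the discriminator returns, a per-pixel logit map from the pixel discriminator drops into the existing 'gan' branch unchanged.

    import torch
    import torch.nn as nn

    criterion = nn.BCEWithLogitsLoss()

    def generator_adv_loss(net_d, fake_img, l_gan_w=1.0):
        pred_g_fake = net_d(fake_img)               # e.g. an (N, 1, H/8, W/8) logit map for a pixel discriminator
        target_real = torch.ones_like(pred_g_fake)  # every location should be judged "real"
        return l_gan_w * criterion(pred_g_fake, target_real)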
@@ -344,8 +344,8 @@ class SRGANModel(BaseModel):
             PIXDISC_OUTPUT_REDUCTION = 8
             PIXDISC_MAX_REDUCTION = 32
             disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
-            real = torch.ones(disc_output_shape)
-            fake = torch.zeros(disc_output_shape)
+            real = torch.ones(disc_output_shape, device=var_ref[0].device)
+            fake = torch.zeros(disc_output_shape, device=var_ref[0].device)
             # randomly determine portions of the image to swap to keep the discriminator honest.
             if random.random() > .25:
@@ -353,16 +353,16 @@ class SRGANModel(BaseModel):
                 # Make the swap across fake_H and var_ref
                 SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
                 assert SWAP_MAX_DIM > 0
-                swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
+                swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION
                 swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
-                t = fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h].clone()
-                fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h]
-                var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = t
+                t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
+                fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
+                var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
                 # Swap the expectation matrix too.
                 swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
-                real[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 0.0
-                fake[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 1.0
+                real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
+                fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
                 # We're also assuming that this is exactly how the flattened discriminator output is generated.
                 real = real.view(-1, 1)
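
A standalone sketch of the patch-swap trick above (hypothetical tensor shapes and names, not the repo's exact code; assumes square inputs and a 1-channel label map): a rectangle aligned to the discriminator's coarsest reduction is swapped between the fake and real batches, and the matching rectangle is flipped in the per-pixel expectation maps so the labels stay consistent with the swapped pixels and neither batch is uniformly real or fake.

    import random
    import torch

    def swap_patches(fake, real, out_red=8, max_red=32):
        n, _, h, w = real.shape
        ones = torch.ones((n, 1, h // out_red, w // out_red), device=real.device)
        zeros = torch.zeros_like(ones)

        max_dim = h // (2 * max_red) - 1
        assert max_dim > 0
        x = random.randint(0, max_dim + 1) * max_red
        y = random.randint(0, max_dim + 1) * max_red
        pw = random.randint(1, max_dim) * max_red
        ph = random.randint(1, max_dim) * max_red

        # Swap the pixel rectangle between the fake and real images.
        t = fake[:, :, x:x + pw, y:y + ph].clone()
        fake[:, :, x:x + pw, y:y + ph] = real[:, :, x:x + pw, y:y + ph]
        real[:, :, x:x + pw, y:y + ph] = t

        # Flip the same rectangle in the expectation maps, then flatten them
        # the same way the discriminator output is flattened.
        xo, yo, pwo, pho = x // out_red, y // out_red, pw // out_red, ph // out_red
        ones[:, :, xo:xo + pwo, yo:yo + pho] = 0.0
        zeros[:, :, xo:xo + pwo, yo:yo + pho] = 1.0
        return fake, real, ones.view(-1, 1), zeros.view(-1, 1)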


@@ -161,7 +161,8 @@ class Discriminator_VGG_PixLoss(nn.Module):
         dec2 = torch.cat([dec2, fea2], dim=1)
         dec2 = self.up2_converge(dec2)
         dec2 = self.up2_proc(dec2)
-        loss2 = self.up2_reduce(dec2)
+        dec2 = self.up2_reduce(dec2)
+        loss2 = self.up2_pix(dec2)
         # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
         # then know how to handle them.
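
For context, a sketch of the kind of per-pixel head this hunk wires up (the up2_reduce/up2_pix layer shapes here are assumptions, not the repo's definitions): a reducing convolution followed by a final single-channel convolution whose logit map is flattened into the batch dimension, matching the .view(-1, 1) flattening of the expectation maps.

    import torch
    import torch.nn as nn

    class PixHead(nn.Module):
        def __init__(self, in_ch=64):
            super().__init__()
            self.up2_reduce = nn.Conv2d(in_ch, in_ch // 2, kernel_size=3, padding=1)
            self.up2_pix = nn.Conv2d(in_ch // 2, 1, kernel_size=1)

        def forward(self, dec2):
            dec2 = self.up2_reduce(dec2)
            loss2 = self.up2_pix(dec2)
            # Compress per-pixel predictions into the batch dimension so the
            # attached loss can compare them against the flattened label maps.
            return loss2.view(-1, 1)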


@@ -46,7 +46,7 @@ class GANLoss(nn.Module):
             return torch.empty_like(input).fill_(self.fake_label_val)

     def forward(self, input, target_is_real):
-        if self.gan_type == 'pixgan':
+        if self.gan_type == 'pixgan' and not isinstance(target_is_real, bool):
             target_label = target_is_real
         else:
             target_label = self.get_target_label(input, target_is_real)
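
A minimal standalone sketch of what the guarded branch above amounts to (hypothetical class, assuming a BCE-with-logits criterion rather than the repo's configured one): a tensor passed as target_is_real is treated as a ready-made per-pixel label map, while a plain bool is still expanded to the prediction's shape, which keeps the generator-side calls with True/False working.

    import torch
    import torch.nn as nn

    class PixGanLoss(nn.Module):
        def __init__(self, real_label_val=1.0, fake_label_val=0.0):
            super().__init__()
            self.real_label_val = real_label_val
            self.fake_label_val = fake_label_val
            self.loss = nn.BCEWithLogitsLoss()

        def get_target_label(self, input, target_is_real):
            val = self.real_label_val if target_is_real else self.fake_label_val
            return torch.empty_like(input).fill_(val)

        def forward(self, input, target_is_real):
            if not isinstance(target_is_real, bool):
                target_label = target_is_real  # per-pixel label map supplied by the caller
            else:
                target_label = self.get_target_label(input, target_is_real)
            return self.loss(input, target_label)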


@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_rrdb_pixgan_normal_gan.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_rrdb.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)