Fix up pixgan loss and pixdisc

James Betker 2020-07-08 21:27:48 -06:00
parent 26a4a66d1c
commit b2507be13c
4 changed files with 13 additions and 12 deletions

@@ -271,7 +271,7 @@ class SRGANModel(BaseModel):
                 self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
             if self.l_gan_w > 0:
-                if self.opt['train']['gan_type'] == 'gan':
+                if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
                     pred_g_fake = self.netD(fake_GenOut)
                     l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                 elif self.opt['train']['gan_type'] == 'ragan':
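
Note on the generator branch: 'pixgan' can share the plain-'gan' path because GANLoss expands a boolean target into a full label map the size of the discriminator's per-pixel output (see the GANLoss hunk below). A minimal sketch of that broadcast, with an illustrative BCE criterion and shapes that are assumptions rather than the repo's:

import torch
import torch.nn as nn

# Sketch only: the criterion and shapes are illustrative assumptions.
# A scalar True becomes an all-ones target map matching the logits.
criterion = nn.BCEWithLogitsLoss()
pred_g_fake = torch.randn(4, 1, 16, 16)            # per-pixel logits from netD
target = torch.empty_like(pred_g_fake).fill_(1.0)  # True -> all-"real" map
l_g_gan = criterion(pred_g_fake, target)
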
@@ -344,8 +344,8 @@ class SRGANModel(BaseModel):
                 PIXDISC_OUTPUT_REDUCTION = 8
                 PIXDISC_MAX_REDUCTION = 32
                 disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
-                real = torch.ones(disc_output_shape)
-                fake = torch.zeros(disc_output_shape)
+                real = torch.ones(disc_output_shape, device=var_ref[0].device)
+                fake = torch.zeros(disc_output_shape, device=var_ref[0].device)

                 # randomly determine portions of the image to swap to keep the discriminator honest.
                 if random.random() > .25:
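
The new device= argument pins the label tensors to the reference batch's device; without it they are allocated on the CPU and the loss raises a device-mismatch error against CUDA logits. A minimal sketch of the corrected pattern (shapes illustrative):

import torch

# Allocate labels on the same device as the logits they will be compared to.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
logits = torch.randn(2, 3, 16, 16, device=dev)
real = torch.ones(logits.shape, device=logits.device)
fake = torch.zeros(logits.shape, device=logits.device)
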
@@ -353,16 +353,16 @@ class SRGANModel(BaseModel):
                     # Make the swap across fake_H and var_ref
                     SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
                     assert SWAP_MAX_DIM > 0
-                    swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
+                    swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION
                     swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
-                    t = fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h].clone()
-                    fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h]
-                    var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = t
+                    t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
+                    fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
+                    var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t

                     # Swap the expectation matrix too.
                     swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
-                    real[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 0.0
-                    fake[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 1.0
+                    real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
+                    fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0

                 # We're also assuming that this is exactly how the flattened discriminator output is generated.
                 real = real.view(-1, 1)
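
For reference, a self-contained sketch of the swap trick as a whole, with illustrative constants and shapes rather than the repo's: a 32px-aligned patch is exchanged between the fake and real batches, and the 8x-reduced expectation maps are flipped over the same region. random.randint is inclusive on both ends, so drawing the origin from 0..SWAP_MAX_DIM+1 widens its range by one aligned step relative to the extent:

import random
import torch

OUT_RED, MAX_RED = 8, 32                 # illustrative reductions
fake_h = torch.rand(4, 3, 128, 128)      # stand-ins for fake_H and var_ref[0]
ref = torch.rand(4, 3, 128, 128)
real = torch.ones(4, 3, 128 // OUT_RED, 128 // OUT_RED)
fake = torch.zeros_like(real)

max_dim = 128 // (2 * MAX_RED) - 1
x, y = (random.randint(0, max_dim + 1) * MAX_RED for _ in range(2))
w, h = (random.randint(1, max_dim) * MAX_RED for _ in range(2))

# Exchange the patch between the two batches...
t = fake_h[:, :, x:x + w, y:y + h].clone()
fake_h[:, :, x:x + w, y:y + h] = ref[:, :, x:x + w, y:y + h]
ref[:, :, x:x + w, y:y + h] = t

# ...and flip the matching region of the downsampled label maps.
x, y, w, h = (v // OUT_RED for v in (x, y, w, h))
real[:, :, x:x + w, y:y + h] = 0.0
fake[:, :, x:x + w, y:y + h] = 1.0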

@@ -161,7 +161,8 @@ class Discriminator_VGG_PixLoss(nn.Module):
         dec2 = torch.cat([dec2, fea2], dim=1)
         dec2 = self.up2_converge(dec2)
         dec2 = self.up2_proc(dec2)
-        loss2 = self.up2_reduce(dec2)
+        dec2 = self.up2_reduce(dec2)
+        loss2 = self.up2_pix(dec2)

         # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
         # then know how to handle them.
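
The fix above separates channel reduction from the final prediction: dec2 now passes through up2_reduce and then the up2_pix head, where previously loss2 was taken straight from up2_reduce and up2_pix was never applied. A sketch of the corrected two-stage head (channel widths are assumptions):

import torch
import torch.nn as nn

# up2_reduce compresses features; up2_pix maps them to the 1-channel
# per-pixel logits that loss2 should actually carry.
dec2 = torch.randn(1, 64, 32, 32)
up2_reduce = nn.Conv2d(64, 8, kernel_size=3, padding=1)
up2_pix = nn.Conv2d(8, 1, kernel_size=1)
loss2 = up2_pix(up2_reduce(dec2))        # shape (1, 1, 32, 32)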

@@ -46,7 +46,7 @@ class GANLoss(nn.Module):
             return torch.empty_like(input).fill_(self.fake_label_val)

     def forward(self, input, target_is_real):
-        if self.gan_type == 'pixgan':
+        if self.gan_type == 'pixgan' and not isinstance(target_is_real, bool):
             target_label = target_is_real
         else:
             target_label = self.get_target_label(input, target_is_real)
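
With the isinstance guard, pixgan callers can pass either a plain bool (expanded to a constant label map, as the generator step does) or a prebuilt per-pixel tensor such as the swapped real/fake maps above. A sketch of that dispatch, with the GANLoss internals abbreviated:

import torch

def pick_target(input, target_is_real, real_val=1.0, fake_val=0.0):
    # Tensor targets pass through untouched; bools expand to a constant map.
    if not isinstance(target_is_real, bool):
        return target_is_real
    return torch.empty_like(input).fill_(real_val if target_is_real else fake_val)

logits = torch.randn(2, 1, 4, 4)
assert pick_target(logits, True).mean() == 1.0   # bool -> all-ones map
custom = torch.rand(2, 1, 4, 4)
assert pick_target(logits, custom) is custom     # tensor -> used as-is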

@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_rrdb_pixgan_normal_gan.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_rrdb.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
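
Assuming train.py is launched from its own directory (the default option path is relative), a single-process run against the new config would look like:

python train.py -opt ../options/train_div2k_pixgan_rrdb.yml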