forked from mrq/DL-Art-School
Fix up pixgan loss and pixdisc
This commit is contained in:
parent
26a4a66d1c
commit
b2507be13c
|
@ -271,7 +271,7 @@ class SRGANModel(BaseModel):
|
||||||
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
||||||
|
|
||||||
if self.l_gan_w > 0:
|
if self.l_gan_w > 0:
|
||||||
if self.opt['train']['gan_type'] == 'gan':
|
if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
|
||||||
pred_g_fake = self.netD(fake_GenOut)
|
pred_g_fake = self.netD(fake_GenOut)
|
||||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||||
elif self.opt['train']['gan_type'] == 'ragan':
|
elif self.opt['train']['gan_type'] == 'ragan':
|
||||||
|
@ -344,8 +344,8 @@ class SRGANModel(BaseModel):
|
||||||
PIXDISC_OUTPUT_REDUCTION = 8
|
PIXDISC_OUTPUT_REDUCTION = 8
|
||||||
PIXDISC_MAX_REDUCTION = 32
|
PIXDISC_MAX_REDUCTION = 32
|
||||||
disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
|
disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
|
||||||
real = torch.ones(disc_output_shape)
|
real = torch.ones(disc_output_shape, device=var_ref[0].device)
|
||||||
fake = torch.zeros(disc_output_shape)
|
fake = torch.zeros(disc_output_shape, device=var_ref[0].device)
|
||||||
|
|
||||||
# randomly determine portions of the image to swap to keep the discriminator honest.
|
# randomly determine portions of the image to swap to keep the discriminator honest.
|
||||||
if random.random() > .25:
|
if random.random() > .25:
|
||||||
|
@ -353,16 +353,16 @@ class SRGANModel(BaseModel):
|
||||||
# Make the swap across fake_H and var_ref
|
# Make the swap across fake_H and var_ref
|
||||||
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
|
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
|
||||||
assert SWAP_MAX_DIM > 0
|
assert SWAP_MAX_DIM > 0
|
||||||
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION
|
||||||
swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
||||||
t = fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h].clone()
|
t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
|
||||||
fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h]
|
fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
|
||||||
var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = t
|
var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
|
||||||
|
|
||||||
# Swap the expectation matrix too.
|
# Swap the expectation matrix too.
|
||||||
swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
|
swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
|
||||||
real[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 0.0
|
real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
|
||||||
fake[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 1.0
|
fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
|
||||||
|
|
||||||
# We're also assuming that this is exactly how the flattened discriminator output is generated.
|
# We're also assuming that this is exactly how the flattened discriminator output is generated.
|
||||||
real = real.view(-1, 1)
|
real = real.view(-1, 1)
|
||||||
|
|
|
@ -161,7 +161,8 @@ class Discriminator_VGG_PixLoss(nn.Module):
|
||||||
dec2 = torch.cat([dec2, fea2], dim=1)
|
dec2 = torch.cat([dec2, fea2], dim=1)
|
||||||
dec2 = self.up2_converge(dec2)
|
dec2 = self.up2_converge(dec2)
|
||||||
dec2 = self.up2_proc(dec2)
|
dec2 = self.up2_proc(dec2)
|
||||||
loss2 = self.up2_reduce(dec2)
|
dec2 = self.up2_reduce(dec2)
|
||||||
|
loss2 = self.up2_pix(dec2)
|
||||||
|
|
||||||
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
||||||
# then know how to handle them.
|
# then know how to handle them.
|
||||||
|
|
|
@ -46,7 +46,7 @@ class GANLoss(nn.Module):
|
||||||
return torch.empty_like(input).fill_(self.fake_label_val)
|
return torch.empty_like(input).fill_(self.fake_label_val)
|
||||||
|
|
||||||
def forward(self, input, target_is_real):
|
def forward(self, input, target_is_real):
|
||||||
if self.gan_type == 'pixgan':
|
if self.gan_type == 'pixgan' and not isinstance(target_is_real, bool):
|
||||||
target_label = target_is_real
|
target_label = target_is_real
|
||||||
else:
|
else:
|
||||||
target_label = self.get_target_label(input, target_is_real)
|
target_label = self.get_target_label(input, target_is_real)
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_rrdb_pixgan_normal_gan.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_rrdb.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user