forked from mrq/DL-Art-School
Fix up pixgan loss and pixdisc
This commit is contained in:
parent
26a4a66d1c
commit
b2507be13c
|
@ -271,7 +271,7 @@ class SRGANModel(BaseModel):
|
|||
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
||||
|
||||
if self.l_gan_w > 0:
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
|
||||
pred_g_fake = self.netD(fake_GenOut)
|
||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
|
@ -344,8 +344,8 @@ class SRGANModel(BaseModel):
|
|||
PIXDISC_OUTPUT_REDUCTION = 8
|
||||
PIXDISC_MAX_REDUCTION = 32
|
||||
disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
|
||||
real = torch.ones(disc_output_shape)
|
||||
fake = torch.zeros(disc_output_shape)
|
||||
real = torch.ones(disc_output_shape, device=var_ref[0].device)
|
||||
fake = torch.zeros(disc_output_shape, device=var_ref[0].device)
|
||||
|
||||
# randomly determine portions of the image to swap to keep the discriminator honest.
|
||||
if random.random() > .25:
|
||||
|
@ -353,16 +353,16 @@ class SRGANModel(BaseModel):
|
|||
# Make the swap across fake_H and var_ref
|
||||
SWAP_MAX_DIM = var_ref[0].shape[2] // (2 * PIXDISC_MAX_REDUCTION) - 1
|
||||
assert SWAP_MAX_DIM > 0
|
||||
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
||||
swap_x, swap_y = random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION, random.randint(0, SWAP_MAX_DIM+1) * PIXDISC_MAX_REDUCTION
|
||||
swap_w, swap_h = random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION, random.randint(1, SWAP_MAX_DIM) * PIXDISC_MAX_REDUCTION
|
||||
t = fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h].clone()
|
||||
fake_H[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h]
|
||||
var_ref[0][:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = t
|
||||
t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
|
||||
fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
|
||||
var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
|
||||
|
||||
# Swap the expectation matrix too.
|
||||
swap_x, swap_y, swap_w, swap_h = swap_x // PIXDISC_OUTPUT_REDUCTION, swap_y // PIXDISC_OUTPUT_REDUCTION, swap_w // PIXDISC_OUTPUT_REDUCTION, swap_h // PIXDISC_OUTPUT_REDUCTION
|
||||
real[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 0.0
|
||||
fake[:, :, swap_x:swap_x+swap_w, swap_y:swap_y+swap_h] = 1.0
|
||||
real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
|
||||
fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
|
||||
|
||||
# We're also assuming that this is exactly how the flattened discriminator output is generated.
|
||||
real = real.view(-1, 1)
|
||||
|
|
|
@ -161,7 +161,8 @@ class Discriminator_VGG_PixLoss(nn.Module):
|
|||
dec2 = torch.cat([dec2, fea2], dim=1)
|
||||
dec2 = self.up2_converge(dec2)
|
||||
dec2 = self.up2_proc(dec2)
|
||||
loss2 = self.up2_reduce(dec2)
|
||||
dec2 = self.up2_reduce(dec2)
|
||||
loss2 = self.up2_pix(dec2)
|
||||
|
||||
# Compress all of the loss values into the batch dimension. The actual loss attached to this output will
|
||||
# then know how to handle them.
|
||||
|
|
|
@ -46,7 +46,7 @@ class GANLoss(nn.Module):
|
|||
return torch.empty_like(input).fill_(self.fake_label_val)
|
||||
|
||||
def forward(self, input, target_is_real):
|
||||
if self.gan_type == 'pixgan':
|
||||
if self.gan_type == 'pixgan' and not isinstance(target_is_real, bool):
|
||||
target_label = target_is_real
|
||||
else:
|
||||
target_label = self.get_target_label(input, target_is_real)
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_rrdb_pixgan_normal_gan.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_pixgan_rrdb.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user