Fix pixdisc bugs

This commit is contained in:
James Betker 2020-07-31 16:38:14 -06:00
parent eb11a08d1c
commit bcebed19b7

View File

@ -455,13 +455,13 @@ class SRGANModel(BaseModel):
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
if 'pixgan' in self.opt['train']['gan_type']: if 'pixgan' in self.opt['train']['gan_type']:
if not self.disjoint_data:
# randomly determine portions of the image to swap to keep the discriminator honest.
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
b, _, w, h = var_ref.shape b, _, w, h = var_ref.shape
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device) fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
if not self.disjoint_data:
# randomly determine portions of the image to swap to keep the discriminator honest.
SWAP_MAX_DIM = w // 4 SWAP_MAX_DIM = w // 4
SWAP_MIN_DIM = 16 SWAP_MIN_DIM = 16
assert SWAP_MAX_DIM > 0 assert SWAP_MAX_DIM > 0