Fix pixdisc bugs
This commit is contained in:
parent
eb11a08d1c
commit
bcebed19b7
|
@ -455,13 +455,13 @@ class SRGANModel(BaseModel):
|
||||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||||
l_d_fake_scaled.backward()
|
l_d_fake_scaled.backward()
|
||||||
if 'pixgan' in self.opt['train']['gan_type']:
|
if 'pixgan' in self.opt['train']['gan_type']:
|
||||||
if not self.disjoint_data:
|
|
||||||
# randomly determine portions of the image to swap to keep the discriminator honest.
|
|
||||||
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
|
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
|
||||||
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
|
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
|
||||||
b, _, w, h = var_ref.shape
|
b, _, w, h = var_ref.shape
|
||||||
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
|
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
|
||||||
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
|
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
|
||||||
|
if not self.disjoint_data:
|
||||||
|
# randomly determine portions of the image to swap to keep the discriminator honest.
|
||||||
SWAP_MAX_DIM = w // 4
|
SWAP_MAX_DIM = w // 4
|
||||||
SWAP_MIN_DIM = 16
|
SWAP_MIN_DIM = 16
|
||||||
assert SWAP_MAX_DIM > 0
|
assert SWAP_MAX_DIM > 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user