Fix pixdisc bugs
This commit is contained in:
parent
eb11a08d1c
commit
bcebed19b7
|
@ -455,13 +455,13 @@ class SRGANModel(BaseModel):
|
|||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
if 'pixgan' in self.opt['train']['gan_type']:
|
||||
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
|
||||
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
|
||||
b, _, w, h = var_ref.shape
|
||||
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
|
||||
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
|
||||
if not self.disjoint_data:
|
||||
# randomly determine portions of the image to swap to keep the discriminator honest.
|
||||
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
|
||||
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
|
||||
b, _, w, h = var_ref.shape
|
||||
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
|
||||
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
|
||||
SWAP_MAX_DIM = w // 4
|
||||
SWAP_MIN_DIM = 16
|
||||
assert SWAP_MAX_DIM > 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user