diff --git a/codes/models/styled_sr/styled_sr.py b/codes/models/styled_sr/styled_sr.py
index c40bb00b..07447dd7 100644
--- a/codes/models/styled_sr/styled_sr.py
+++ b/codes/models/styled_sr/styled_sr.py
@@ -130,9 +130,11 @@ class StyledSrGenerator(nn.Module):
         # Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later.
         self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
         self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
+        self.l2 = nn.MSELoss()
         self.mixed_prob = .9
         self._init_weights()
         self.transfer_mode = transfer_mode
+        self.initial_stride = initial_stride
         if transfer_mode:
             for p in self.parameters():
                 if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
@@ -174,11 +176,14 @@ class StyledSrGenerator(nn.Module):
 
         out = self.gen(x, w_styles)
 
-        # Compute the net, areal, pixel-wise additions made on top of the LR image.
-        out_down = F.interpolate(out, size=(x.shape[-2], x.shape[-1]), mode="area")
-        diff = torch.sum(torch.abs(out_down - x), dim=[1,2,3])
+        # Compute an L2 loss on the areal interpolation of the generated image back down to LR * initial_stride; used
+        # for regularization.
+        out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area")
+        if self.initial_stride > 1:
+            x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area")
+        l2_reg = self.l2(x, out_down)
 
-        return out, diff, w_styles
+        return out, l2_reg, w_styles
 
 
 if __name__ == '__main__':
diff --git a/codes/models/styled_sr/stylegan2_base.py b/codes/models/styled_sr/stylegan2_base.py
index 2738725a..dff9be14 100644
--- a/codes/models/styled_sr/stylegan2_base.py
+++ b/codes/models/styled_sr/stylegan2_base.py
@@ -396,8 +396,8 @@ class GeneratorBlock(nn.Module):
             x = self.upsample(x)
 
         inoise = inoise[:, :x.shape[2], :x.shape[3], :]
-        noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
-        noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
+        noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
+        noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
 
         style1 = self.to_style1(istyle)
         x = self.conv1(x, style1)
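
Note on the change (not part of the patch): a minimal standalone sketch, assuming illustrative tensor shapes and initial_stride=2, of the new regularization term in styled_sr.py. The SR output is area-interpolated down to the LR resolution divided by initial_stride, the LR input is downsampled by the same factor when the stride exceeds 1, and the two are compared with MSE; the forward pass then returns this scalar alongside the SR image instead of the old per-sample absolute-difference sum.

    import torch
    import torch.nn.functional as F
    from torch import nn

    initial_stride = 2                       # assumed value, for illustration only
    l2 = nn.MSELoss()

    x = torch.randn(1, 3, 32, 32)            # LR input; shapes are illustrative
    out = torch.randn(1, 3, 128, 128)        # generated SR output

    # Area-downsample the SR output to (LR size // initial_stride).
    out_down = F.interpolate(out, size=(x.shape[-2] // initial_stride,
                                        x.shape[-1] // initial_stride), mode="area")
    # Bring the LR input to the same resolution before comparing.
    if initial_stride > 1:
        x = F.interpolate(x, scale_factor=1 / initial_stride, mode="area")
    l2_reg = l2(x, out_down)                 # scalar L2 regularization term

The permute change in stylegan2_base.py converts the NHWC output of the noise projection to NCHW; a small sketch (continuing the imports above, with assumed non-square spatial dims and a hypothetical 8-channel projection) shows that (0, 3, 1, 2) keeps H and W in place, while the previous (0, 3, 2, 1) transposed them:

    inoise = torch.randn(1, 64, 48, 1)               # N x H x W x 1 noise map
    to_noise = nn.Linear(1, 8)                       # hypothetical channel count
    noise = to_noise(inoise).permute((0, 3, 1, 2))   # -> (1, 8, 64, 48): NCHW, H/W preserved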