Styled SR fixes

This commit is contained in:
James Betker 2021-01-05 20:14:39 -07:00
parent 9fed90393f
commit 2f2f87bbea
2 changed files with 11 additions and 6 deletions

View File

@ -130,9 +130,11 @@ class StyledSrGenerator(nn.Module):
# Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later. # Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later.
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False) self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode) self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
self.l2 = nn.MSELoss()
self.mixed_prob = .9 self.mixed_prob = .9
self._init_weights() self._init_weights()
self.transfer_mode = transfer_mode self.transfer_mode = transfer_mode
self.initial_stride = initial_stride
if transfer_mode: if transfer_mode:
for p in self.parameters(): for p in self.parameters():
if not hasattr(p, 'FOR_TRANSFER_LEARNING'): if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
@ -174,11 +176,14 @@ class StyledSrGenerator(nn.Module):
out = self.gen(x, w_styles) out = self.gen(x, w_styles)
# Compute the net, areal, pixel-wise additions made on top of the LR image. # Compute an L2 loss on the areal interpolation of the generated image back down to LR * initial_stride; used
out_down = F.interpolate(out, size=(x.shape[-2], x.shape[-1]), mode="area") # for regularization.
diff = torch.sum(torch.abs(out_down - x), dim=[1,2,3]) out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area")
if self.initial_stride > 1:
x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area")
l2_reg = self.l2(x, out_down)
return out, diff, w_styles return out, l2_reg, w_styles
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -396,8 +396,8 @@ class GeneratorBlock(nn.Module):
x = self.upsample(x) x = self.upsample(x)
inoise = inoise[:, :x.shape[2], :x.shape[3], :] inoise = inoise[:, :x.shape[2], :x.shape[3], :]
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1)) noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1)) noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
style1 = self.to_style1(istyle) style1 = self.to_style1(istyle)
x = self.conv1(x, style1) x = self.conv1(x, style1)