forked from mrq/DL-Art-School
Styled SR fixes
This commit is contained in:
parent
9fed90393f
commit
2f2f87bbea
|
@ -130,9 +130,11 @@ class StyledSrGenerator(nn.Module):
|
|||
# Assume the vectorizer doesnt need transfer_mode=True. Re-evaluate this later.
|
||||
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp, transfer_mode=False)
|
||||
self.gen = Generator(image_size=image_size, latent_dim=latent_dim, initial_stride=initial_stride, transfer_mode=transfer_mode)
|
||||
self.l2 = nn.MSELoss()
|
||||
self.mixed_prob = .9
|
||||
self._init_weights()
|
||||
self.transfer_mode = transfer_mode
|
||||
self.initial_stride = initial_stride
|
||||
if transfer_mode:
|
||||
for p in self.parameters():
|
||||
if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
|
||||
|
@ -174,11 +176,14 @@ class StyledSrGenerator(nn.Module):
|
|||
|
||||
out = self.gen(x, w_styles)
|
||||
|
||||
# Compute the net, areal, pixel-wise additions made on top of the LR image.
|
||||
out_down = F.interpolate(out, size=(x.shape[-2], x.shape[-1]), mode="area")
|
||||
diff = torch.sum(torch.abs(out_down - x), dim=[1,2,3])
|
||||
# Compute an L2 loss on the areal interpolation of the generated image back down to LR * initial_stride; used
|
||||
# for regularization.
|
||||
out_down = F.interpolate(out, size=(x.shape[-2] // self.initial_stride, x.shape[-1] // self.initial_stride), mode="area")
|
||||
if self.initial_stride > 1:
|
||||
x = F.interpolate(x, scale_factor=1/self.initial_stride, mode="area")
|
||||
l2_reg = self.l2(x, out_down)
|
||||
|
||||
return out, diff, w_styles
|
||||
return out, l2_reg, w_styles
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -396,8 +396,8 @@ class GeneratorBlock(nn.Module):
|
|||
x = self.upsample(x)
|
||||
|
||||
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
|
||||
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
|
||||
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
|
||||
noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
|
||||
noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
|
||||
|
||||
style1 = self.to_style1(istyle)
|
||||
x = self.conv1(x, style1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user