styled_sr: fix bug when using initial_stride

This commit is contained in:
James Betker 2021-01-01 12:13:21 -07:00
parent 913fc3b75e
commit f39179e85a

View File

@ -94,12 +94,14 @@ class Generator(nn.Module):
def forward(self, lr, styles):
b, c, h, w = lr.shape
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
rgb = lr
styles = styles.transpose(0, 1)
x = self.encoder(lr)
styles = styles.transpose(0, 1)
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
if h != x.shape[-2]:
rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
else:
rgb = lr
for style, block in zip(styles, self.blocks):
x, rgb = checkpoint(block, x, rgb, style, input_noise)
@ -153,8 +155,8 @@ class StyledSrGenerator(nn.Module):
if __name__ == '__main__':
gen = StyledSrGenerator(128)
out = gen(torch.rand(1,3,32,32))
gen = StyledSrGenerator(128, 2)
out = gen(torch.rand(1,3,64,64))
print([o.shape for o in out])