forked from mrq/DL-Art-School
styled_sr: fix bug when using initial_stride
This commit is contained in:
parent
913fc3b75e
commit
f39179e85a
|
@ -94,12 +94,14 @@ class Generator(nn.Module):
|
||||||
|
|
||||||
def forward(self, lr, styles):
|
def forward(self, lr, styles):
|
||||||
b, c, h, w = lr.shape
|
b, c, h, w = lr.shape
|
||||||
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
|
|
||||||
|
|
||||||
rgb = lr
|
|
||||||
styles = styles.transpose(0, 1)
|
|
||||||
|
|
||||||
x = self.encoder(lr)
|
x = self.encoder(lr)
|
||||||
|
|
||||||
|
styles = styles.transpose(0, 1)
|
||||||
|
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
|
||||||
|
if h != x.shape[-2]:
|
||||||
|
rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
|
||||||
|
else:
|
||||||
|
rgb = lr
|
||||||
for style, block in zip(styles, self.blocks):
|
for style, block in zip(styles, self.blocks):
|
||||||
x, rgb = checkpoint(block, x, rgb, style, input_noise)
|
x, rgb = checkpoint(block, x, rgb, style, input_noise)
|
||||||
|
|
||||||
|
@ -153,8 +155,8 @@ class StyledSrGenerator(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gen = StyledSrGenerator(128)
|
gen = StyledSrGenerator(128, 2)
|
||||||
out = gen(torch.rand(1,3,32,32))
|
out = gen(torch.rand(1,3,64,64))
|
||||||
print([o.shape for o in out])
|
print([o.shape for o in out])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user