forked from mrq/DL-Art-School
styled_sr: fix bug when using initial_stride
This commit is contained in:
parent
913fc3b75e
commit
f39179e85a
|
@ -94,12 +94,14 @@ class Generator(nn.Module):
|
|||
|
||||
def forward(self, lr, styles):
|
||||
b, c, h, w = lr.shape
|
||||
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
|
||||
|
||||
rgb = lr
|
||||
styles = styles.transpose(0, 1)
|
||||
|
||||
x = self.encoder(lr)
|
||||
|
||||
styles = styles.transpose(0, 1)
|
||||
input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device)
|
||||
if h != x.shape[-2]:
|
||||
rgb = F.interpolate(lr, size=x.shape[2:], mode="area")
|
||||
else:
|
||||
rgb = lr
|
||||
for style, block in zip(styles, self.blocks):
|
||||
x, rgb = checkpoint(block, x, rgb, style, input_noise)
|
||||
|
||||
|
@ -153,8 +155,8 @@ class StyledSrGenerator(nn.Module):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gen = StyledSrGenerator(128)
|
||||
out = gen(torch.rand(1,3,32,32))
|
||||
gen = StyledSrGenerator(128, 2)
|
||||
out = gen(torch.rand(1,3,64,64))
|
||||
print([o.shape for o in out])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user