diff --git a/codes/models/improve_rrdb/styled_sr.py b/codes/models/improve_rrdb/styled_sr.py index dea13bc2..da080204 100644 --- a/codes/models/improve_rrdb/styled_sr.py +++ b/codes/models/improve_rrdb/styled_sr.py @@ -94,12 +94,14 @@ class Generator(nn.Module): def forward(self, lr, styles): b, c, h, w = lr.shape - input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device) - - rgb = lr - styles = styles.transpose(0, 1) - x = self.encoder(lr) + + styles = styles.transpose(0, 1) + input_noise = torch.rand(b, h * self.scale, w * self.scale, 1).to(lr.device) + if h != x.shape[-2]: + rgb = F.interpolate(lr, size=x.shape[2:], mode="area") + else: + rgb = lr for style, block in zip(styles, self.blocks): x, rgb = checkpoint(block, x, rgb, style, input_noise) @@ -153,8 +155,8 @@ class StyledSrGenerator(nn.Module): if __name__ == '__main__': - gen = StyledSrGenerator(128) - out = gen(torch.rand(1,3,32,32)) + gen = StyledSrGenerator(128, 2) + out = gen(torch.rand(1,3,64,64)) print([o.shape for o in out])