diff --git a/codes/models/archs/stylegan2.py b/codes/models/archs/stylegan2.py index a1c5a1cb..a676ef95 100644 --- a/codes/models/archs/stylegan2.py +++ b/codes/models/archs/stylegan2.py @@ -428,8 +428,7 @@ class GeneratorBlock(nn.Module): x = self.upsample(x) if self.structure_input: - s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest") - s = self.structure_conv(s) + s = self.structure_conv(structure_input) x = torch.cat([x, s], dim=1) inoise = inoise[:, :x.shape[2], :x.shape[3], :] @@ -535,10 +534,24 @@ class Generator(nn.Module): styles = styles.transpose(0, 1) x = self.initial_conv(x) + if structure_input is not None: + s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest") for style, block, attn in zip(styles, self.blocks, self.attns): if exists(attn): x = checkpoint(attn, x) - x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input) + if structure_input is not None: + if exists(block.upsample): + # In this case, the structural guidance is given by the extra information over the previous layer. + twoX = (x.shape[2]*2, x.shape[3]*2) + sn = torch.nn.functional.interpolate(structure_input, size=twoX, mode="nearest") + s_int = torch.nn.functional.interpolate(s, size=twoX, mode="bilinear") + s_diff = sn - s_int + else: + # This is the initial case - just feed in the base structure. + s_diff = s + else: + s_diff = None + x, rgb = checkpoint(block, x, rgb, style, input_noise, s_diff) return rgb