Alter how structural guidance is given to stylegan

This commit is contained in:
James Betker 2020-11-14 20:15:48 -07:00
parent 3397c83447
commit c9258e2da3

View File

@ -428,8 +428,7 @@ class GeneratorBlock(nn.Module):
x = self.upsample(x) x = self.upsample(x)
if self.structure_input: if self.structure_input:
s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest") s = self.structure_conv(structure_input)
s = self.structure_conv(s)
x = torch.cat([x, s], dim=1) x = torch.cat([x, s], dim=1)
inoise = inoise[:, :x.shape[2], :x.shape[3], :] inoise = inoise[:, :x.shape[2], :x.shape[3], :]
@ -535,10 +534,24 @@ class Generator(nn.Module):
styles = styles.transpose(0, 1) styles = styles.transpose(0, 1)
x = self.initial_conv(x) x = self.initial_conv(x)
if structure_input is not None:
s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
for style, block, attn in zip(styles, self.blocks, self.attns): for style, block, attn in zip(styles, self.blocks, self.attns):
if exists(attn): if exists(attn):
x = checkpoint(attn, x) x = checkpoint(attn, x)
x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input) if structure_input is not None:
if exists(block.upsample):
# In this case, the structural guidance is given by the extra information over the previous layer.
twoX = (x.shape[2]*2, x.shape[3]*2)
sn = torch.nn.functional.interpolate(structure_input, size=twoX, mode="nearest")
s_int = torch.nn.functional.interpolate(s, size=twoX, mode="bilinear")
s_diff = sn - s_int
else:
# This is the initial case - just feed in the base structure.
s_diff = s
else:
s_diff = None
x, rgb = checkpoint(block, x, rgb, style, input_noise, s_diff)
return rgb return rgb