forked from mrq/DL-Art-School
Alter how structural guidance is given to stylegan
This commit is contained in:
parent
3397c83447
commit
c9258e2da3
|
@ -428,8 +428,7 @@ class GeneratorBlock(nn.Module):
|
|||
x = self.upsample(x)
|
||||
|
||||
if self.structure_input:
|
||||
s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
|
||||
s = self.structure_conv(s)
|
||||
s = self.structure_conv(structure_input)
|
||||
x = torch.cat([x, s], dim=1)
|
||||
|
||||
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
|
||||
|
@ -535,10 +534,24 @@ class Generator(nn.Module):
|
|||
styles = styles.transpose(0, 1)
|
||||
x = self.initial_conv(x)
|
||||
|
||||
if structure_input is not None:
|
||||
s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
|
||||
for style, block, attn in zip(styles, self.blocks, self.attns):
|
||||
if exists(attn):
|
||||
x = checkpoint(attn, x)
|
||||
x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input)
|
||||
if structure_input is not None:
|
||||
if exists(block.upsample):
|
||||
# In this case, the structural guidance is given by the extra information over the previous layer.
|
||||
twoX = (x.shape[2]*2, x.shape[3]*2)
|
||||
sn = torch.nn.functional.interpolate(structure_input, size=twoX, mode="nearest")
|
||||
s_int = torch.nn.functional.interpolate(s, size=twoX, mode="bilinear")
|
||||
s_diff = sn - s_int
|
||||
else:
|
||||
# This is the initial case - just feed in the base structure.
|
||||
s_diff = s
|
||||
else:
|
||||
s_diff = None
|
||||
x, rgb = checkpoint(block, x, rgb, style, input_noise, s_diff)
|
||||
|
||||
return rgb
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user