Mods to support stylegan2 in SR mode

This commit is contained in:
James Betker 2020-11-13 20:11:50 -07:00
parent 67bf55495b
commit f406a5dd4c
5 changed files with 62 additions and 18 deletions

View File

@ -403,10 +403,15 @@ class Conv2DMod(nn.Module):
class GeneratorBlock(nn.Module):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
self.structure_input = structure_input
if self.structure_input:
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
input_channels = input_channels * 2
self.to_style1 = nn.Linear(latent_dim, input_channels)
self.to_noise1 = nn.Linear(1, filters)
self.conv1 = Conv2DMod(input_channels, filters, 3)
@ -418,10 +423,15 @@ class GeneratorBlock(nn.Module):
self.activation = leaky_relu()
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
def forward(self, x, prev_rgb, istyle, inoise):
def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
if exists(self.upsample):
x = self.upsample(x)
if self.structure_input:
s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
s = self.structure_conv(s)
x = torch.cat([x, s], dim=1)
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
@ -466,7 +476,7 @@ class DiscriminatorBlock(nn.Module):
class Generator(nn.Module):
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
fmap_max=512):
fmap_max=512, structure_input=False):
super().__init__()
self.image_size = image_size
self.latent_dim = latent_dim
@ -506,11 +516,12 @@ class Generator(nn.Module):
out_chan,
upsample=not_first,
upsample_rgb=not_last,
rgba=transparent
rgba=transparent,
structure_input=structure_input
)
self.blocks.append(block)
def forward(self, styles, input_noise):
def forward(self, styles, input_noise, structure_input=None):
batch_size = styles.shape[0]
image_size = self.image_size
@ -527,17 +538,19 @@ class Generator(nn.Module):
for style, block, attn in zip(styles, self.blocks, self.attns):
if exists(attn):
x = attn(x)
x, rgb = checkpoint(block, x, rgb, style, input_noise)
x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input)
return rgb
# Wrapper that combines style vectorizer with the actual generator.
class StyleGan2GeneratorWithLatent(nn.Module):
def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False, attn_layers=[], no_const=False, fmap_max=512):
def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False,
attn_layers=[], no_const=False, fmap_max=512, structure_input=False):
super().__init__()
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max)
self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max,
structure_input=structure_input)
self.mixed_prob = .9
self._init_weights()
@ -559,7 +572,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
# b,f,h,w.
def forward(self, x):
def forward(self, x, structure_input=None):
b, f, h, w = x.shape
full_random_latents = True
@ -583,7 +596,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
w_styles = self.styles_def_to_tensor(w_space)
# The underlying model expects the noise as b,h,w,1. Make it so.
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input), w_styles
def _init_weights(self):
for m in self.modules():
@ -599,13 +612,12 @@ class StyleGan2GeneratorWithLatent(nn.Module):
class StyleGan2Discriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512):
transparent=False, fmap_max=512, input_filters=3):
super().__init__()
num_layers = int(log2(image_size) - 1)
num_init_filters = 3 if not transparent else 4
blocks = []
filters = [num_init_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
set_fmap_max = partial(min, fmap_max)
filters = list(map(set_fmap_max, filters))

View File

@ -133,8 +133,9 @@ def define_G(opt, net_key='network_G', scale=None):
elif which_model == "linear_latent_estimator":
netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
elif which_model == 'stylegan2':
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
style_depth=opt_net['style_depth'])
style_depth=opt_net['style_depth'], structure_input=is_structured)
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG
@ -194,7 +195,7 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
elif which_model == "pyramid_disc":
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
elif which_model == "stylegan2_discriminator":
disc = StyleGan2Discriminator(image_size=opt_net['image_size'])
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

View File

@ -54,6 +54,8 @@ def create_injector(opt_inject, env):
return PsnrInjector(opt_inject, env)
elif type == 'batch_rotate':
return BatchRotateInjector(opt_inject, env)
elif type == 'sr_diffs':
return SrDiffsInjector(opt_inject, env)
else:
raise NotImplementedError
@ -379,3 +381,25 @@ class BatchRotateInjector(Injector):
img = state[self.input]
return {self.output: torch.roll(img, 1, 0)}
# Injector used to work with image deltas used in diff-SR
class SrDiffsInjector(Injector):
def __init__(self, opt, env):
super(SrDiffsInjector, self).__init__(opt, env)
self.mode = opt['mode']
assert self.mode in ['recombine', 'produce_diff']
self.lq = opt['lq']
self.hq = opt['hq']
if self.mode == 'produce_diff':
self.diff_key = opt['diff']
def forward(self, state):
resampled_lq = state[self.lq]
hq = state[self.hq]
if self.mode == 'produce_diff':
diff = hq - resampled_lq
return {self.output: torch.cat([resampled_lq, diff], dim=1),
self.diff_key: diff}
elif self.mode == 'recombine':
combined = resampled_lq + hq
return {self.output: combined}

View File

@ -497,14 +497,21 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
self.discriminator = opt['discriminator']
self.for_gen = opt['gen_loss']
self.gp_frequency = opt['gradient_penalty_frequency']
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
def forward(self, net, state):
real_input = state[self.real]
fake_input = state[self.fake]
if self.noise != 0:
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
real_input = real_input + torch.rand_like(real_input) * self.noise
D = self.env['discriminators'][self.discriminator]
fake = D(state[self.fake])
fake = D(fake_input)
if self.for_gen:
return fake.mean()
else:
real_input = state[self.real].requires_grad_() # <-- Needed to compute gradients on the input.
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
real = D(real_input)
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()