Mods to support stylegan2 in SR mode
This commit is contained in:
parent
67bf55495b
commit
f406a5dd4c
|
@ -403,10 +403,15 @@ class Conv2DMod(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class GeneratorBlock(nn.Module):
|
class GeneratorBlock(nn.Module):
|
||||||
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
|
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
|
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
|
||||||
|
|
||||||
|
self.structure_input = structure_input
|
||||||
|
if self.structure_input:
|
||||||
|
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
|
||||||
|
input_channels = input_channels * 2
|
||||||
|
|
||||||
self.to_style1 = nn.Linear(latent_dim, input_channels)
|
self.to_style1 = nn.Linear(latent_dim, input_channels)
|
||||||
self.to_noise1 = nn.Linear(1, filters)
|
self.to_noise1 = nn.Linear(1, filters)
|
||||||
self.conv1 = Conv2DMod(input_channels, filters, 3)
|
self.conv1 = Conv2DMod(input_channels, filters, 3)
|
||||||
|
@ -418,10 +423,15 @@ class GeneratorBlock(nn.Module):
|
||||||
self.activation = leaky_relu()
|
self.activation = leaky_relu()
|
||||||
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
|
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
|
||||||
|
|
||||||
def forward(self, x, prev_rgb, istyle, inoise):
|
def forward(self, x, prev_rgb, istyle, inoise, structure_input=None):
|
||||||
if exists(self.upsample):
|
if exists(self.upsample):
|
||||||
x = self.upsample(x)
|
x = self.upsample(x)
|
||||||
|
|
||||||
|
if self.structure_input:
|
||||||
|
s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
|
||||||
|
s = self.structure_conv(s)
|
||||||
|
x = torch.cat([x, s], dim=1)
|
||||||
|
|
||||||
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
|
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
|
||||||
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
|
noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
|
||||||
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
|
noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
|
||||||
|
@ -466,7 +476,7 @@ class DiscriminatorBlock(nn.Module):
|
||||||
|
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
|
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
|
||||||
fmap_max=512):
|
fmap_max=512, structure_input=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.latent_dim = latent_dim
|
self.latent_dim = latent_dim
|
||||||
|
@ -506,11 +516,12 @@ class Generator(nn.Module):
|
||||||
out_chan,
|
out_chan,
|
||||||
upsample=not_first,
|
upsample=not_first,
|
||||||
upsample_rgb=not_last,
|
upsample_rgb=not_last,
|
||||||
rgba=transparent
|
rgba=transparent,
|
||||||
|
structure_input=structure_input
|
||||||
)
|
)
|
||||||
self.blocks.append(block)
|
self.blocks.append(block)
|
||||||
|
|
||||||
def forward(self, styles, input_noise):
|
def forward(self, styles, input_noise, structure_input=None):
|
||||||
batch_size = styles.shape[0]
|
batch_size = styles.shape[0]
|
||||||
image_size = self.image_size
|
image_size = self.image_size
|
||||||
|
|
||||||
|
@ -527,17 +538,19 @@ class Generator(nn.Module):
|
||||||
for style, block, attn in zip(styles, self.blocks, self.attns):
|
for style, block, attn in zip(styles, self.blocks, self.attns):
|
||||||
if exists(attn):
|
if exists(attn):
|
||||||
x = attn(x)
|
x = attn(x)
|
||||||
x, rgb = checkpoint(block, x, rgb, style, input_noise)
|
x, rgb = checkpoint(block, x, rgb, style, input_noise, structure_input)
|
||||||
|
|
||||||
return rgb
|
return rgb
|
||||||
|
|
||||||
|
|
||||||
# Wrapper that combines style vectorizer with the actual generator.
|
# Wrapper that combines style vectorizer with the actual generator.
|
||||||
class StyleGan2GeneratorWithLatent(nn.Module):
|
class StyleGan2GeneratorWithLatent(nn.Module):
|
||||||
def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False, attn_layers=[], no_const=False, fmap_max=512):
|
def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False,
|
||||||
|
attn_layers=[], no_const=False, fmap_max=512, structure_input=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
|
self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
|
||||||
self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max)
|
self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max,
|
||||||
|
structure_input=structure_input)
|
||||||
self.mixed_prob = .9
|
self.mixed_prob = .9
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
||||||
|
@ -559,7 +572,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
||||||
|
|
||||||
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
|
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
|
||||||
# b,f,h,w.
|
# b,f,h,w.
|
||||||
def forward(self, x):
|
def forward(self, x, structure_input=None):
|
||||||
b, f, h, w = x.shape
|
b, f, h, w = x.shape
|
||||||
|
|
||||||
full_random_latents = True
|
full_random_latents = True
|
||||||
|
@ -583,7 +596,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
||||||
w_styles = self.styles_def_to_tensor(w_space)
|
w_styles = self.styles_def_to_tensor(w_space)
|
||||||
|
|
||||||
# The underlying model expects the noise as b,h,w,1. Make it so.
|
# The underlying model expects the noise as b,h,w,1. Make it so.
|
||||||
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles
|
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input), w_styles
|
||||||
|
|
||||||
def _init_weights(self):
|
def _init_weights(self):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
|
@ -599,13 +612,12 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
||||||
|
|
||||||
class StyleGan2Discriminator(nn.Module):
|
class StyleGan2Discriminator(nn.Module):
|
||||||
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
||||||
transparent=False, fmap_max=512):
|
transparent=False, fmap_max=512, input_filters=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
num_layers = int(log2(image_size) - 1)
|
num_layers = int(log2(image_size) - 1)
|
||||||
num_init_filters = 3 if not transparent else 4
|
|
||||||
|
|
||||||
blocks = []
|
blocks = []
|
||||||
filters = [num_init_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
|
filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)]
|
||||||
|
|
||||||
set_fmap_max = partial(min, fmap_max)
|
set_fmap_max = partial(min, fmap_max)
|
||||||
filters = list(map(set_fmap_max, filters))
|
filters = list(map(set_fmap_max, filters))
|
||||||
|
|
|
@ -133,8 +133,9 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
elif which_model == "linear_latent_estimator":
|
elif which_model == "linear_latent_estimator":
|
||||||
netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
|
netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
|
||||||
elif which_model == 'stylegan2':
|
elif which_model == 'stylegan2':
|
||||||
|
is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False
|
||||||
netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
netG = StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
|
||||||
style_depth=opt_net['style_depth'])
|
style_depth=opt_net['style_depth'], structure_input=is_structured)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
return netG
|
return netG
|
||||||
|
@ -194,7 +195,7 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
elif which_model == "pyramid_disc":
|
elif which_model == "pyramid_disc":
|
||||||
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
|
netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
|
||||||
elif which_model == "stylegan2_discriminator":
|
elif which_model == "stylegan2_discriminator":
|
||||||
disc = StyleGan2Discriminator(image_size=opt_net['image_size'])
|
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'])
|
||||||
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
netD = StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
|
|
|
@ -54,6 +54,8 @@ def create_injector(opt_inject, env):
|
||||||
return PsnrInjector(opt_inject, env)
|
return PsnrInjector(opt_inject, env)
|
||||||
elif type == 'batch_rotate':
|
elif type == 'batch_rotate':
|
||||||
return BatchRotateInjector(opt_inject, env)
|
return BatchRotateInjector(opt_inject, env)
|
||||||
|
elif type == 'sr_diffs':
|
||||||
|
return SrDiffsInjector(opt_inject, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -379,3 +381,25 @@ class BatchRotateInjector(Injector):
|
||||||
img = state[self.input]
|
img = state[self.input]
|
||||||
return {self.output: torch.roll(img, 1, 0)}
|
return {self.output: torch.roll(img, 1, 0)}
|
||||||
|
|
||||||
|
|
||||||
|
# Injector used to work with image deltas used in diff-SR
|
||||||
|
class SrDiffsInjector(Injector):
    """Injector that converts between images and image deltas for diff-SR.

    Behaviour is selected by ``opt['mode']``:

    * ``'produce_diff'``: computes ``diff = state[hq] - state[lq]`` and emits
      the concatenation ``[lq, diff]`` along dim 1 under ``self.output``, plus
      the bare ``diff`` under the key named by ``opt['diff']``.
    * ``'recombine'``: emits ``state[lq] + state[hq]`` under ``self.output``.

    Options read from ``opt``:
        mode: One of ``'recombine'`` or ``'produce_diff'``.
        lq: State key of the low-quality tensor.
        hq: State key of the second tensor.
        diff: State key the raw diff is written to (``'produce_diff'`` only).

    NOTE(review): ``self.output`` is not assigned in this class; it is
    presumably set by the ``Injector`` base class from ``opt`` - confirm.
    NOTE(review): the local name ``resampled_lq`` suggests the LQ tensor has
    already been resized to match ``hq``; no resizing happens here - confirm
    against the caller.
    """

    # Modes this injector knows how to handle.
    _MODES = ('recombine', 'produce_diff')

    def __init__(self, opt, env):
        super(SrDiffsInjector, self).__init__(opt, env)
        self.mode = opt['mode']
        # Validate with an explicit raise rather than `assert`: asserts are
        # stripped under `python -O`, which would let a bad mode through and
        # make forward() silently return None.
        if self.mode not in self._MODES:
            raise ValueError(
                "SrDiffsInjector mode must be one of {}, got {!r}".format(self._MODES, self.mode))
        self.lq = opt['lq']
        self.hq = opt['hq']
        if self.mode == 'produce_diff':
            # Only this mode emits the raw diff, so only it needs the key.
            self.diff_key = opt['diff']

    def forward(self, state):
        """Compute this injector's outputs from ``state``.

        Args:
            state: Mapping that contains tensors under the ``lq`` and ``hq`` keys.

        Returns:
            A dict of produced tensors: ``{output: cat([lq, diff], dim=1),
            diff_key: diff}`` in ``'produce_diff'`` mode, or
            ``{output: lq + hq}`` in ``'recombine'`` mode.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        resampled_lq = state[self.lq]
        hq = state[self.hq]
        if self.mode == 'produce_diff':
            diff = hq - resampled_lq
            return {self.output: torch.cat([resampled_lq, diff], dim=1),
                    self.diff_key: diff}
        if self.mode == 'recombine':
            # NOTE(review): in this mode `hq` is presumably a delta (e.g. a
            # predicted diff) that is added back onto the LQ image - confirm.
            return {self.output: resampled_lq + hq}
        # Only reachable if `mode` was changed after construction; fail loudly
        # instead of returning None.
        raise ValueError("SrDiffsInjector has unsupported mode {!r}".format(self.mode))
|
||||||
|
|
|
@ -497,14 +497,21 @@ class StyleGan2DivergenceLoss(ConfigurableLoss):
|
||||||
self.discriminator = opt['discriminator']
|
self.discriminator = opt['discriminator']
|
||||||
self.for_gen = opt['gen_loss']
|
self.for_gen = opt['gen_loss']
|
||||||
self.gp_frequency = opt['gradient_penalty_frequency']
|
self.gp_frequency = opt['gradient_penalty_frequency']
|
||||||
|
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
|
real_input = state[self.real]
|
||||||
|
fake_input = state[self.fake]
|
||||||
|
if self.noise != 0:
|
||||||
|
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
|
||||||
|
real_input = real_input + torch.rand_like(real_input) * self.noise
|
||||||
|
|
||||||
D = self.env['discriminators'][self.discriminator]
|
D = self.env['discriminators'][self.discriminator]
|
||||||
fake = D(state[self.fake])
|
fake = D(fake_input)
|
||||||
if self.for_gen:
|
if self.for_gen:
|
||||||
return fake.mean()
|
return fake.mean()
|
||||||
else:
|
else:
|
||||||
real_input = state[self.real].requires_grad_() # <-- Needed to compute gradients on the input.
|
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
|
||||||
real = D(real_input)
|
real = D(real_input)
|
||||||
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
||||||
|
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user