Rework stylegan_for_sr to incorporate structure as an adain block

This commit is contained in:
James Betker 2020-11-23 11:31:11 -07:00
parent 519ba6f10c
commit b10bcf6436
9 changed files with 209 additions and 42 deletions

View File

@ -222,4 +222,3 @@ class RRDBNet(nn.Module):
for i, bm in enumerate(self.body):
if hasattr(bm, 'bypass_map'):
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.srflow_orig.module_util as mutil
from models.archs.arch_util import default_init_weights, ConvGnSilu
from models.archs.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu
from utils.util import opt_get
@ -231,3 +231,27 @@ class RRDBNet(nn.Module):
return results
else:
return out
class RRDBLatentWrapper(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, with_bypass, blocks, pretrain_rrdb_path=None, gc=32, scale=4):
super().__init__()
self.with_bypass = with_bypass
self.blocks = blocks
fake_opt = { 'networks': {'generator': {'flow': {'stackRRDB': {'blocks': blocks}}, 'rrdb_bypass': with_bypass}}}
self.wrappedRRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, fake_opt)
if pretrain_rrdb_path is not None:
rrdb_state_dict = torch.load(pretrain_rrdb_path)
self.wrappedRRDB.load_state_dict(rrdb_state_dict, strict=True)
out_dim = nf * (len(blocks) + 1)
self.postprocess = nn.Sequential(ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True),
ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True),
ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=False, norm=False))
def forward(self, lr):
rrdbResults = self.wrappedRRDB(lr, get_steps=True)
blocklist = [rrdbResults["block_{}".format(idx)] for idx in self.blocks]
blocklist.append(rrdbResults['last_lr_fea'])
fea = torch.cat(blocklist, dim=1)
fea = self.postprocess(fea)
return fea

View File

@ -1,3 +1,4 @@
import functools
import math
import multiprocessing
from contextlib import contextmanager, ExitStack
@ -371,6 +372,76 @@ class RGBBlock(nn.Module):
return x
class AdaptiveInstanceNorm(nn.Module):
def __init__(self, in_channel, style_dim):
super().__init__()
from models.archs.arch_util import ConvGnLelu
self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True)
self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0)
self.norm = nn.InstanceNorm2d(in_channel)
def forward(self, input, style):
gamma = self.style2scale(style)
beta = self.style2bias(style)
out = self.norm(input)
out = gamma * out + beta
return out
class NoiseInjection(nn.Module):
def __init__(self, channel):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
def forward(self, image, noise):
return image + self.weight * noise
class EqualLR:
def __init__(self, name):
self.name = name
def compute_weight(self, module):
weight = getattr(module, self.name + '_orig')
fan_in = weight.data.size(1) * weight.data[0][0].numel()
return weight * math.sqrt(2 / fan_in)
@staticmethod
def apply(module, name):
fn = EqualLR(name)
weight = getattr(module, name)
del module._parameters[name]
module.register_parameter(name + '_orig', nn.Parameter(weight.data))
module.register_forward_pre_hook(fn)
return fn
def __call__(self, module, input):
weight = self.compute_weight(module)
setattr(module, self.name, weight)
def equal_lr(module, name='weight'):
EqualLR.apply(module, name)
return module
class EqualConv2d(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
conv = nn.Conv2d(*args, **kwargs)
conv.weight.data.normal_()
conv.bias.data.zero_()
self.conv = equal_lr(conv)
def forward(self, input):
return self.conv(input)
class Conv2DMod(nn.Module):
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs):
super().__init__()
@ -408,6 +479,54 @@ class Conv2DMod(nn.Module):
return x
class GeneratorBlockWithStructure(nn.Module):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
# Uses stylegan1 style blocks for injecting structural latent.
self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
self.to_noise0 = nn.Linear(1, filters)
self.noise0 = equal_lr(NoiseInjection(filters))
self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
self.to_style1 = nn.Linear(latent_dim, filters)
self.to_noise1 = nn.Linear(1, filters)
self.conv1 = Conv2DMod(filters, filters, 3)
self.to_style2 = nn.Linear(latent_dim, filters)
self.to_noise2 = nn.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
def forward(self, x, prev_rgb, istyle, inoise, structure_input):
if exists(self.upsample):
x = self.upsample(x)
inoise = inoise[:, :x.shape[2], :x.shape[3], :]
noise0 = self.to_noise0(inoise).permute((0, 3, 1, 2))
noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2))
noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2))
structure = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest")
x = self.conv0(x)
x = self.noise0(x, noise0)
x = self.adain0(x, structure)
style1 = self.to_style1(istyle)
x = self.conv1(x, style1)
x = self.activation(x + noise1)
style2 = self.to_style2(istyle)
x = self.conv2(x, style2)
x = self.activation(x + noise2)
rgb = self.to_rgb(x, prev_rgb, istyle)
return x, rgb
class GeneratorBlock(nn.Module):
def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False):
super().__init__()
@ -453,32 +572,6 @@ class GeneratorBlock(nn.Module):
return x, rgb
class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True):
super().__init__()
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu()
)
self.downsample = nn.Sequential(
Blur(),
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
) if downsample else None
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
x = self.downsample(x)
x = (x + res) * (1 / math.sqrt(2))
return x
class Generator(nn.Module):
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False,
fmap_max=512, structure_input=False):
@ -515,18 +608,22 @@ class Generator(nn.Module):
self.attns.append(attn_fn)
block = GeneratorBlock(
if structure_input:
block_fn = GeneratorBlockWithStructure
else:
block_fn = GeneratorBlock
block = block_fn(
latent_dim,
in_chan,
out_chan,
upsample=not_first,
upsample_rgb=not_last,
rgba=transparent,
structure_input=structure_input
rgba=transparent
)
self.blocks.append(block)
def forward(self, styles, input_noise, structure_input=None):
def forward(self, styles, input_noise, structure_input=None, starting_shape=None):
batch_size = styles.shape[0]
image_size = self.image_size
@ -535,6 +632,8 @@ class Generator(nn.Module):
x = self.to_initial_block(avg_style)
else:
x = self.initial_block.expand(batch_size, -1, -1, -1)
if starting_shape is not None:
x = F.interpolate(x, size=starting_shape, mode="bilinear")
rgb = None
styles = styles.transpose(0, 1)
@ -591,7 +690,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
# b,f,h,w.
def forward(self, x, structure_input=None):
def forward(self, x, structure_input=None, fit_starting_shape_to_structure=False):
b, f, h, w = x.shape
full_random_latents = True
@ -614,12 +713,15 @@ class StyleGan2GeneratorWithLatent(nn.Module):
w_space = self.latent_to_w(self.vectorizer, style)
w_styles = self.styles_def_to_tensor(w_space)
starting_shape = None
if fit_starting_shape_to_structure:
starting_shape = (x.shape[2] // 32, x.shape[3] // 32)
# The underlying model expects the noise as b,h,w,1. Make it so.
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input), w_styles
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input, starting_shape), w_styles
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks:
@ -629,6 +731,32 @@ class StyleGan2GeneratorWithLatent(nn.Module):
nn.init.zeros_(block.to_noise2.bias)
class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True):
super().__init__()
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu()
)
self.downsample = nn.Sequential(
Blur(),
nn.Conv2d(filters, filters, 3, padding=1, stride=2)
) if downsample else None
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
x = self.downsample(x)
x = (x + res) * (1 / math.sqrt(2))
return x
class StyleGan2Discriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512, input_filters=3):

View File

@ -22,6 +22,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
self.im_sz = opt_eval['image_size']
self.scale = opt_eval['scale']
self.fid_real_samples = opt_eval['real_fid_path']
self.embedding_generator = opt_eval['embedding_generator']
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
self.dataset = Stylegan2Dataset({'path': self.fid_real_samples,
'target_size': self.im_sz,
@ -30,6 +31,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
self.sampler = BatchSampler(self.dataset, self.batch_sz, False)
def perform_eval(self):
embedding_generator = self.env['generators'][self.embedding_generator]
fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True)
fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"]))
@ -40,7 +42,8 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
batch_hq = [e['GT'] for e in batch]
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
gen = self.model(noise, resized_batch)
embedding = embedding_generator(resized_batch)
gen = self.model(noise, embedding)
if not isinstance(gen, list) and not isinstance(gen, tuple):
gen = [gen]
gen = gen[self.gen_output_index]

View File

@ -148,6 +148,11 @@ def define_G(opt, opt_net, scale=None):
from models.archs.srflow_orig import SRFlowNet_arch
netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
K=opt_net['K'], opt=opt)
elif which_model == 'rrdb_latent_wrapper':
from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG

View File

@ -78,15 +78,20 @@ class Injector(torch.nn.Module):
class ImageGeneratorInjector(Injector):
def __init__(self, opt, env):
super(ImageGeneratorInjector, self).__init__(opt, env)
self.grad = opt['grad'] if 'grad' in opt.keys() else True
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
with autocast(enabled=self.env['opt']['fp16']):
if isinstance(self.input, list):
params = extract_params_from_state(self.input, state)
else:
params = [state[self.input]]
if self.grad:
results = gen(*params)
else:
results = gen(state[self.input])
with torch.no_grad():
results = gen(*params)
new_state = {}
if isinstance(self.output, list):
# Only dereference tuples or lists, not tensors.

View File

@ -13,7 +13,7 @@ import torch
def main():
split_img = False
opt = {}
opt['n_thread'] = 5
opt['n_thread'] = 20
opt['compression_level'] = 90 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.
@ -46,6 +46,9 @@ class TiledDataset(data.Dataset):
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
# Greyscale not supported.
if img is None:
print("Error with ", path)
return None
if len(img.shape) == 2:
return None
h, w, c = img.shape

View File

@ -31,8 +31,8 @@ class Trainer:
def init(self, opt, launcher, all_networks={}):
self._profile = False
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'].keys() else True
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'].keys() else True
#### loading resume state if exists
if opt['path'].get('resume_state', None):
@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_v2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_srflow.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()