diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index a7fec89d..cddd1c4f 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -222,4 +222,3 @@ class RRDBNet(nn.Module): for i, bm in enumerate(self.body): if hasattr(bm, 'bypass_map'): torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) - diff --git a/codes/models/archs/srflow_orig/RRDBNet_arch.py b/codes/models/archs/srflow_orig/RRDBNet_arch.py index 607f4e0e..3828d7b8 100644 --- a/codes/models/archs/srflow_orig/RRDBNet_arch.py +++ b/codes/models/archs/srflow_orig/RRDBNet_arch.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import models.archs.srflow_orig.module_util as mutil -from models.archs.arch_util import default_init_weights, ConvGnSilu +from models.archs.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu from utils.util import opt_get @@ -231,3 +231,27 @@ class RRDBNet(nn.Module): return results else: return out + + +class RRDBLatentWrapper(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, with_bypass, blocks, pretrain_rrdb_path=None, gc=32, scale=4): + super().__init__() + self.with_bypass = with_bypass + self.blocks = blocks + fake_opt = { 'networks': {'generator': {'flow': {'stackRRDB': {'blocks': blocks}}, 'rrdb_bypass': with_bypass}}} + self.wrappedRRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, fake_opt) + if pretrain_rrdb_path is not None: + rrdb_state_dict = torch.load(pretrain_rrdb_path) + self.wrappedRRDB.load_state_dict(rrdb_state_dict, strict=True) + out_dim = nf * (len(blocks) + 1) + self.postprocess = nn.Sequential(ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True), + ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True), + ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=False, norm=False)) + + def forward(self, lr): + rrdbResults = self.wrappedRRDB(lr, get_steps=True) + blocklist = [rrdbResults["block_{}".format(idx)] for idx in self.blocks] + blocklist.append(rrdbResults['last_lr_fea']) + fea = torch.cat(blocklist, dim=1) + fea = self.postprocess(fea) + return fea \ No newline at end of file diff --git a/codes/models/archs/stylegan/stylegan2.py b/codes/models/archs/stylegan/stylegan2.py index 16735af3..8cd9365c 100644 --- a/codes/models/archs/stylegan/stylegan2.py +++ b/codes/models/archs/stylegan/stylegan2.py @@ -1,3 +1,4 @@ +import functools import math import multiprocessing from contextlib import contextmanager, ExitStack @@ -371,6 +372,76 @@ class RGBBlock(nn.Module): return x +class AdaptiveInstanceNorm(nn.Module): + def __init__(self, in_channel, style_dim): + super().__init__() + from models.archs.arch_util import ConvGnLelu + self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True) + self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0) + self.norm = nn.InstanceNorm2d(in_channel) + + def forward(self, input, style): + gamma = self.style2scale(style) + beta = self.style2bias(style) + out = self.norm(input) + out = gamma * out + beta + return out + + +class NoiseInjection(nn.Module): + def __init__(self, channel): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) + + def forward(self, image, noise): + return image + self.weight * noise + + +class EqualLR: + def __init__(self, name): + self.name = name + + def compute_weight(self, module): + weight = getattr(module, self.name + '_orig') + fan_in = weight.data.size(1) * weight.data[0][0].numel() + + return weight * math.sqrt(2 / fan_in) + + @staticmethod + def apply(module, name): + fn = EqualLR(name) + + weight = getattr(module, name) + del module._parameters[name] + module.register_parameter(name + '_orig', nn.Parameter(weight.data)) + module.register_forward_pre_hook(fn) + + return fn + + def __call__(self, module, input): + weight = self.compute_weight(module) + setattr(module, self.name, weight) + + +def equal_lr(module, name='weight'): + EqualLR.apply(module, name) + + return module + + +class EqualConv2d(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + conv = nn.Conv2d(*args, **kwargs) + conv.weight.data.normal_() + conv.bias.data.zero_() + self.conv = equal_lr(conv) + + def forward(self, input): + return self.conv(input) + + class Conv2DMod(nn.Module): def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, **kwargs): super().__init__() @@ -408,6 +479,54 @@ class Conv2DMod(nn.Module): return x +class GeneratorBlockWithStructure(nn.Module): + def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False): + super().__init__() + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None + + # Uses stylegan1 style blocks for injecting structural latent. + self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1) + self.to_noise0 = nn.Linear(1, filters) + self.noise0 = equal_lr(NoiseInjection(filters)) + self.adain0 = AdaptiveInstanceNorm(filters, latent_dim) + + self.to_style1 = nn.Linear(latent_dim, filters) + self.to_noise1 = nn.Linear(1, filters) + self.conv1 = Conv2DMod(filters, filters, 3) + + self.to_style2 = nn.Linear(latent_dim, filters) + self.to_noise2 = nn.Linear(1, filters) + self.conv2 = Conv2DMod(filters, filters, 3) + + self.activation = leaky_relu() + self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba) + + def forward(self, x, prev_rgb, istyle, inoise, structure_input): + if exists(self.upsample): + x = self.upsample(x) + + inoise = inoise[:, :x.shape[2], :x.shape[3], :] + noise0 = self.to_noise0(inoise).permute((0, 3, 1, 2)) + noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2)) + noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2)) + + structure = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest") + x = self.conv0(x) + x = self.noise0(x, noise0) + x = self.adain0(x, structure) + + style1 = self.to_style1(istyle) + x = self.conv1(x, style1) + x = self.activation(x + noise1) + + style2 = self.to_style2(istyle) + x = self.conv2(x, style2) + x = self.activation(x + noise2) + + rgb = self.to_rgb(x, prev_rgb, istyle) + return x, rgb + + class GeneratorBlock(nn.Module): def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False): super().__init__() @@ -453,32 +572,6 @@ class GeneratorBlock(nn.Module): return x, rgb -class DiscriminatorBlock(nn.Module): - def __init__(self, input_channels, filters, downsample=True): - super().__init__() - self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1)) - - self.net = nn.Sequential( - nn.Conv2d(input_channels, filters, 3, padding=1), - leaky_relu(), - nn.Conv2d(filters, filters, 3, padding=1), - leaky_relu() - ) - - self.downsample = nn.Sequential( - Blur(), - nn.Conv2d(filters, filters, 3, padding=1, stride=2) - ) if downsample else None - - def forward(self, x): - res = self.conv_res(x) - x = self.net(x) - if exists(self.downsample): - x = self.downsample(x) - x = (x + res) * (1 / math.sqrt(2)) - return x - - class Generator(nn.Module): def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False, fmap_max=512, structure_input=False): @@ -515,18 +608,22 @@ class Generator(nn.Module): self.attns.append(attn_fn) - block = GeneratorBlock( + if structure_input: + block_fn = GeneratorBlockWithStructure + else: + block_fn = GeneratorBlock + + block = block_fn( latent_dim, in_chan, out_chan, upsample=not_first, upsample_rgb=not_last, - rgba=transparent, - structure_input=structure_input + rgba=transparent ) self.blocks.append(block) - def forward(self, styles, input_noise, structure_input=None): + def forward(self, styles, input_noise, structure_input=None, starting_shape=None): batch_size = styles.shape[0] image_size = self.image_size @@ -535,6 +632,8 @@ class Generator(nn.Module): x = self.to_initial_block(avg_style) else: x = self.initial_block.expand(batch_size, -1, -1, -1) + if starting_shape is not None: + x = F.interpolate(x, size=starting_shape, mode="bilinear") rgb = None styles = styles.transpose(0, 1) @@ -591,7 +690,7 @@ class StyleGan2GeneratorWithLatent(nn.Module): # To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format: # b,f,h,w. - def forward(self, x, structure_input=None): + def forward(self, x, structure_input=None, fit_starting_shape_to_structure=False): b, f, h, w = x.shape full_random_latents = True @@ -614,12 +713,15 @@ class StyleGan2GeneratorWithLatent(nn.Module): w_space = self.latent_to_w(self.vectorizer, style) w_styles = self.styles_def_to_tensor(w_space) + starting_shape = None + if fit_starting_shape_to_structure: + starting_shape = (x.shape[2] // 32, x.shape[3] // 32) # The underlying model expects the noise as b,h,w,1. Make it so. - return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input), w_styles + return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input, starting_shape), w_styles def _init_weights(self): for m in self.modules(): - if type(m) in {nn.Conv2d, nn.Linear}: + if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'): nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') for block in self.gen.blocks: @@ -629,6 +731,32 @@ class StyleGan2GeneratorWithLatent(nn.Module): nn.init.zeros_(block.to_noise2.bias) +class DiscriminatorBlock(nn.Module): + def __init__(self, input_channels, filters, downsample=True): + super().__init__() + self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1)) + + self.net = nn.Sequential( + nn.Conv2d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv2d(filters, filters, 3, padding=1), + leaky_relu() + ) + + self.downsample = nn.Sequential( + Blur(), + nn.Conv2d(filters, filters, 3, padding=1, stride=2) + ) if downsample else None + + def forward(self, x): + res = self.conv_res(x) + x = self.net(x) + if exists(self.downsample): + x = self.downsample(x) + x = (x + res) * (1 / math.sqrt(2)) + return x + + class StyleGan2Discriminator(nn.Module): def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[], transparent=False, fmap_max=512, input_filters=3): diff --git a/codes/models/eval/sr_style.py b/codes/models/eval/sr_style.py index ca46fae8..45ca70ee 100644 --- a/codes/models/eval/sr_style.py +++ b/codes/models/eval/sr_style.py @@ -22,6 +22,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator): self.im_sz = opt_eval['image_size'] self.scale = opt_eval['scale'] self.fid_real_samples = opt_eval['real_fid_path'] + self.embedding_generator = opt_eval['embedding_generator'] self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 self.dataset = Stylegan2Dataset({'path': self.fid_real_samples, 'target_size': self.im_sz, @@ -30,6 +31,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator): self.sampler = BatchSampler(self.dataset, self.batch_sz, False) def perform_eval(self): + embedding_generator = self.env['generators'][self.embedding_generator] fid_fake_path = osp.join(self.env['base_path'], "..", "fid_fake", str(self.env["step"])) os.makedirs(fid_fake_path, exist_ok=True) fid_real_path = osp.join(self.env['base_path'], "..", "fid_real", str(self.env["step"])) @@ -40,7 +42,8 @@ class SrStyleTransferEvaluator(evaluator.Evaluator): batch_hq = [e['GT'] for e in batch] batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device']) resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area") - gen = self.model(noise, resized_batch) + embedding = embedding_generator(resized_batch) + gen = self.model(noise, embedding) if not isinstance(gen, list) and not isinstance(gen, tuple): gen = [gen] gen = gen[self.gen_output_index] diff --git a/codes/models/networks.py b/codes/models/networks.py index 75130989..41aa7ab2 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -148,6 +148,11 @@ def define_G(opt, opt_net, scale=None): from models.archs.srflow_orig import SRFlowNet_arch netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['K'], opt=opt) + elif which_model == 'rrdb_latent_wrapper': + from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper + netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], + nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'], + blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path']) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) return netG diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 3cb9366a..06165aba 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -78,15 +78,20 @@ class Injector(torch.nn.Module): class ImageGeneratorInjector(Injector): def __init__(self, opt, env): super(ImageGeneratorInjector, self).__init__(opt, env) + self.grad = opt['grad'] if 'grad' in opt.keys() else True def forward(self, state): gen = self.env['generators'][self.opt['generator']] with autocast(enabled=self.env['opt']['fp16']): if isinstance(self.input, list): params = extract_params_from_state(self.input, state) + else: + params = [state[self.input]] + if self.grad: results = gen(*params) else: - results = gen(state[self.input]) + with torch.no_grad(): + results = gen(*params) new_state = {} if isinstance(self.output, list): # Only dereference tuples or lists, not tensors. diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py index 1a32d912..f037733e 100644 --- a/codes/scripts/extract_square_images.py +++ b/codes/scripts/extract_square_images.py @@ -13,7 +13,7 @@ import torch def main(): split_img = False opt = {} - opt['n_thread'] = 5 + opt['n_thread'] = 20 opt['compression_level'] = 90 # JPEG compression quality rating. # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. @@ -46,6 +46,9 @@ class TiledDataset(data.Dataset): img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # Greyscale not supported. + if img is None: + print("Error with ", path) + return None if len(img.shape) == 2: return None h, w, c = img.shape diff --git a/codes/train.py b/codes/train.py index bb0be9cb..38c1bf2b 100644 --- a/codes/train.py +++ b/codes/train.py @@ -31,8 +31,8 @@ class Trainer: def init(self, opt, launcher, all_networks={}): self._profile = False - self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True - self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True + self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'].keys() else True + self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'].keys() else True #### loading resume state if exists if opt['path'].get('resume_state', None): @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_v2.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/train2.py b/codes/train2.py index be751986..ecef64c3 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_srflow.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()