diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index faa68723..a7973bdb 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -8,6 +8,8 @@ import numpy as np import torch import os +from torchvision.transforms import Normalize + from data import util # Builds a dataset created from a simple folder containing a list of training/test/validation images. from data.image_corruptor import ImageCorruptor @@ -28,6 +30,13 @@ class ImageFolderDataset: # from the same video source. Search for 'fetch_alt_image' for more info. self.skip_lq = opt_get(opt, ['skip_lq'], False) self.disable_flip = opt_get(opt, ['disable_flip'], False) + if 'normalize' in opt.keys(): + if opt['normalize'] == 'stylegan2_norm': + self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + else: + raise Exception('Unsupported normalize') + else: + self.normalize = None assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors. if not isinstance(self.paths, list): self.paths = [self.paths] @@ -128,6 +137,8 @@ class ImageFolderDataset: # Convert to torch tensor hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float() + if self.normalize: + hq = self.normalize(hq) out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]} diff --git a/codes/data/stylegan2_dataset.py b/codes/data/stylegan2_dataset.py index 4281267f..52dde7e2 100644 --- a/codes/data/stylegan2_dataset.py +++ b/codes/data/stylegan2_dataset.py @@ -70,7 +70,7 @@ class expand_greyscale(object): class Stylegan2Dataset(data.Dataset): def __init__(self, opt): super().__init__() - EXTS = ['jpg', 'jpeg', 'png'] + EXTS = ['jpg', 'jpeg', 'png', 'webp'] self.folder = opt['path'] self.image_size = opt['target_size'] self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')] diff --git a/codes/models/stylegan/stylegan2_lucidrains.py b/codes/models/stylegan/stylegan2_lucidrains.py index f61f2b0f..2f08d1a9 100644 --- a/codes/models/stylegan/stylegan2_lucidrains.py +++ b/codes/models/stylegan/stylegan2_lucidrains.py @@ -248,13 +248,12 @@ def gradient_penalty(images, output, weight=10, return_structured_grads=False): return penalty def calc_pl_lengths(styles, images): - device = images.device num_pixels = images.shape[2] * images.shape[3] - pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels) + pl_noise = torch.randn_like(images) / math.sqrt(num_pixels) outputs = (images * pl_noise).sum() pl_grads = torch_grad(outputs=outputs, inputs=styles, - grad_outputs=torch.ones(outputs.shape, device=device), + grad_outputs=torch.ones_like(outputs), create_graph=True, retain_graph=True, only_inputs=True)[0] return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt() @@ -850,6 +849,7 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss): self.for_gen = opt['gen_loss'] self.gp_frequency = opt['gradient_penalty_frequency'] self.noise = opt['noise'] if 'noise' in opt.keys() else 0 + self.logistic = opt_get(opt, ['logistic'], False) # Applies a logistic curve to the output logits, which is what the StyleGAN2 authors used. 
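Note on the new `logistic` option: it switches the divergence term from the default hinge-style loss to the softplus ("non-saturating logistic") loss used by the StyleGAN2 authors. A minimal sketch of what the two branches added to forward() below compute, assuming `real` and `fake` are raw discriminator logits:

    import torch.nn.functional as F

    # logistic=True, generator term: reward fakes that the discriminator scores as real
    g_loss = F.softplus(-fake).mean()
    # logistic=True, discriminator term: push fake logits down, real logits up
    d_loss = F.softplus(fake).mean() + F.softplus(-real).mean()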
def forward(self, net, state): real_input = state[self.real] @@ -861,11 +861,19 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss): D = self.env['discriminators'][self.discriminator] fake = D(fake_input) if self.for_gen: - return fake.mean() + if self.logistic: + return F.softplus(-fake).mean() + else: + return fake.mean() else: real_input.requires_grad_() # <-- Needed to compute gradients on the input. real = D(real_input) - divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean() + if self.logistic: + rl = F.softplus(-real).mean() + fl = F.softplus(fake).mean() + return fl + rl + else: + divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean() # Apply gradient penalty. TODO: migrate this elsewhere. if self.env['step'] % self.gp_frequency == 0: diff --git a/codes/models/stylegan/stylegan2_rosinality.py b/codes/models/stylegan/stylegan2_rosinality.py index ba327894..e97041dc 100644 --- a/codes/models/stylegan/stylegan2_rosinality.py +++ b/codes/models/stylegan/stylegan2_rosinality.py @@ -11,6 +11,10 @@ from torch.autograd import Function # Ops -> The rosinality repo uses native cuda kernels for fused LeakyReLUs and upsamplers. This version extracts the # "cpu" alternative code and uses that instead for compatibility reasons. +from trainer.networks import register_model +from utils.util import opt_get + + class FusedLeakyReLU(nn.Module): def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): super().__init__() @@ -609,11 +613,7 @@ class Generator(nn.Module): image = skip - if return_latents: - return image, latent - - else: - return image, None + return image, latent class ConvLayer(nn.Sequential): @@ -741,3 +741,14 @@ class Discriminator(nn.Module): out = self.final_linear(out) return out + + +@register_model +def register_stylegan2_rosinality_gen(opt_net, opt): + kw = opt_get(opt_net, ['kwargs'], {}) + return Generator(**kw) + +@register_model +def register_stylegan2_rosinality_disc(opt_net, opt): + kw = opt_get(opt_net, ['kwargs'], {}) + return Discriminator(**kw) diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py index 0859fc41..bbb12aea 100644 --- a/codes/scripts/extract_square_images.py +++ b/codes/scripts/extract_square_images.py @@ -13,17 +13,17 @@ import torch def main(): split_img = False opt = {} - opt['n_thread'] = 3 + opt['n_thread'] = 5 opt['compression_level'] = 95 # JPEG compression quality rating. # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. 
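The register_model hooks added to stylegan2_rosinality.py above simply unpack an optional `kwargs` dict into the Generator/Discriminator constructors. A hypothetical network options entry, with constructor arguments mirroring the calls in convert_weights_rosinality.py (Generator(size, 512, 8, channel_multiplier=...), Discriminator(size, channel_multiplier)); the surrounding key and keyword names are assumptions, not taken from this patch:

    gen_net_opt = {
        'which_model': 'stylegan2_rosinality_gen',   # assumed registry key derived from the registered function name
        'kwargs': {'size': 256, 'style_dim': 512, 'n_mlp': 8, 'channel_multiplier': 2},  # keyword names assumed from the rosinality constructor
    }
    disc_net_opt = {
        'which_model': 'stylegan2_rosinality_disc',
        'kwargs': {'size': 256, 'channel_multiplier': 2},
    }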
opt['dest'] = 'file' - opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\working'] - opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped' - opt['imgsize'] = 1024 - opt['bottom_crop'] = .1 - opt['keep_folder'] = True + opt['input_folder'] = ['F:\\4k6k\\datasets\\images\\lsun\\lsun\\cats'] + opt['save_folder'] = 'F:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\cropped' + opt['imgsize'] = 256 + opt['bottom_crop'] = 0 + opt['keep_folder'] = False save_folder = opt['save_folder'] if not osp.exists(save_folder): @@ -83,7 +83,7 @@ class TiledDataset(data.Dataset): pts = os.path.split(pts[0]) output_folder = osp.join(self.opt['save_folder'], pts[-1]) os.makedirs(output_folder, exist_ok=True) - cv2.imwrite(osp.join(output_folder, basename), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) + cv2.imwrite(osp.join(output_folder, basename.replace('.webp', '.jpg')), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']]) return None def __len__(self): diff --git a/codes/scripts/stylegan2/convert_weights_lucidrains.py b/codes/scripts/stylegan2/convert_weights_lucidrains.py deleted file mode 100644 index bc9851f4..00000000 --- a/codes/scripts/stylegan2/convert_weights_lucidrains.py +++ /dev/null @@ -1,292 +0,0 @@ -# Converts from Tensorflow Stylegan2 weights to weights used by this model. -# Original source: https://raw.githubusercontent.com/rosinality/stylegan2-pytorch/master/convert_weight.py -# Adapted to lucidrains' Stylegan implementation. -# -# Also doesn't require you to install Tensorflow 1.15 or clone the nVidia repo. - -# THIS DOES NOT CURRENTLY WORK. -# It does transfer all weights from the stylegan model to the lucidrains one, but does not produce correct results. -# The rosinality script this was stolen from has some "odd" intracacies that may be at cause for this: for example -# weight "flipping" in the conv layers which I do not understand. It may also be because I botched some of the mods -# required to make the lucidrains implementation conformant. I'll (maybe) get back to this some day. - -import argparse -import os -import sys -import pickle -import math - -import torch -import numpy as np -from torchvision import utils - - -# Converts from the TF state_dict input provided into the vars originally expected from the rosinality converter. -from models.stylegan.stylegan2_lucidrains import StyleGan2GeneratorWithLatent - - -def get_vars(vars, source_name): - net_name = source_name.split('/')[0] - vars_as_tuple_list = vars[net_name]['variables'] - result_vars = {} - for t in vars_as_tuple_list: - result_vars[t[0]] = t[1] - return result_vars, source_name.replace(net_name + "/", "") - -def get_vars_direct(vars, source_name): - v, n = get_vars(vars, source_name) - return v[n] - - -def convert_modconv(vars, source_name, target_name, flip=False, numeral=1): - vars, source_name = get_vars(vars, source_name) - weight = vars[source_name + "/weight"] - mod_weight = vars[source_name + "/mod_weight"] - mod_bias = vars[source_name + "/mod_bias"] - noise = vars[source_name + "/noise_strength"] - bias = vars[source_name + "/bias"] - - dic = { - f"conv{numeral}.weight": weight.transpose((3, 2, 0, 1)), - f"to_style{numeral}.weight": mod_weight.transpose((1, 0)), - f"to_style{numeral}.bias": mod_bias + 1, - f"noise{numeral}_scale": np.array([noise]), - f"activation{numeral}.bias": bias, - } - - dic_torch = {} - - for k, v in dic.items(): - dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) - - if flip: - dic_torch[target_name + f".conv{numeral}.weight"] = torch.flip( - dic_torch[target_name + f".conv{numeral}.weight"], [2, 3] - ) - - return dic_torch - - -def convert_conv(vars, source_name, target_name, bias=True, start=0): - vars, source_name = get_vars(vars, source_name) - weight = vars[source_name + "/weight"] - - dic = {"weight": weight.transpose((3, 2, 0, 1))} - - if bias: - dic["bias"] = vars[source_name + "/bias"] - - dic_torch = {} - - dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"]) - - if bias: - dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"]) - - return dic_torch - - -def convert_torgb(vars, source_name, target_name): - vars, source_name = get_vars(vars, source_name) - weight = vars[source_name + "/weight"] - mod_weight = vars[source_name + "/mod_weight"] - mod_bias = vars[source_name + "/mod_bias"] - bias = vars[source_name + "/bias"] - - dic = { - "conv.weight": weight.transpose((3, 2, 0, 1)), - "to_style.weight": mod_weight.transpose((1, 0)), - "to_style.bias": mod_bias + 1, - # "bias": bias.reshape((1, 3, 1, 1)), TODO: where is this? - } - - dic_torch = {} - - for k, v in dic.items(): - dic_torch[target_name + "." + k] = torch.from_numpy(v) - - return dic_torch - - -def convert_dense(vars, source_name, target_name): - vars, source_name = get_vars(vars, source_name) - weight = vars[source_name + "/weight"] - bias = vars[source_name + "/bias"] - - dic = {"weight": weight.transpose((1, 0)), "bias": bias} - - dic_torch = {} - - for k, v in dic.items(): - dic_torch[target_name + "." + k] = torch.from_numpy(v) - - return dic_torch - - -def update(state_dict, new, strict=True): - - for k, v in new.items(): - if strict: - if k not in state_dict: - raise KeyError(k + " is not found") - - if v.shape != state_dict[k].shape: - raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}") - - state_dict[k] = v - - -def discriminator_fill_statedict(statedict, vars, size): - log_size = int(math.log(size, 2)) - - update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) - - conv_i = 1 - - for i in range(log_size - 2, 0, -1): - reso = 4 * 2 ** i - update( - statedict, - convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), - ) - update( - statedict, - convert_conv( - vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 - ), - ) - update( - statedict, - convert_conv( - vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False - ), - ) - conv_i += 1 - - update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) - update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) - update(statedict, convert_dense(vars, f"Output", "final_linear.1")) - - return statedict - - -def fill_statedict(state_dict, vars, size): - log_size = int(math.log(size, 2)) - - for i in range(8): - update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"vectorizer.net.{i}")) - - update( - state_dict, - { - "gen.initial_block": torch.from_numpy( - get_vars_direct(vars, "G_synthesis/4x4/Const/const") - ) - }, - ) - - for i in range(log_size - 1): - reso = 4 * 2 ** i - update( - state_dict, - convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"gen.blocks.{i}.to_rgb"), - ) - - update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "gen.blocks.0", numeral=1)) - - for i in range(1, log_size - 1): - reso = 4 * 2 ** i - update( - state_dict, - convert_modconv( - vars, - f"G_synthesis/{reso}x{reso}/Conv0_up", - 
f"gen.blocks.{i}", - #flip=True, # TODO: why?? - numeral=1 - ), - ) - update( - state_dict, - convert_modconv( - vars, f"G_synthesis/{reso}x{reso}/Conv1", f"gen.blocks.{i}", numeral=2 - ), - ) - - ''' - TODO: consider porting this, though I dont think it is necessary. - for i in range(0, (log_size - 2) * 2 + 1): - update( - state_dict, - { - f"noises.noise_{i}": torch.from_numpy( - get_vars_direct(vars, f"G_synthesis/noise{i}") - ) - }, - ) - ''' - - return state_dict - - -if __name__ == "__main__": - device = "cuda" - - parser = argparse.ArgumentParser( - description="Tensorflow to pytorch model checkpoint converter" - ) - parser.add_argument( - "--gen", action="store_true", help="convert the generator weights" - ) - parser.add_argument( - "--channel_multiplier", - type=int, - default=2, - help="channel multiplier factor. config-f = 2, else = 1", - ) - parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights") - - args = parser.parse_args() - sys.path.append('scripts\\stylegan2') - - import dnnlib - from dnnlib.tflib.network import generator, gen_ema - - with open(args.path, "rb") as f: - pickle.load(f) - - # Weight names are ordered by size. The last name will be something like '1024x1024/'. We just need to grab that first number. - size = int(generator['G_synthesis']['variables'][-1][0].split('x')[0]) - - g = StyleGan2GeneratorWithLatent(image_size=size, latent_dim=512, style_depth=8) - state_dict = g.state_dict() - state_dict = fill_statedict(state_dict, gen_ema, size) - - g.load_state_dict(state_dict, strict=True) - - latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg")) - - ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} - - if args.gen: - g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) - g_train_state = g_train.state_dict() - g_train_state = fill_statedict(g_train_state, generator, size) - ckpt["g"] = g_train_state - - name = os.path.splitext(os.path.basename(args.path))[0] - torch.save(ckpt, name + ".pt") - - batch_size = {256: 16, 512: 9, 1024: 4} - n_sample = batch_size.get(size, 25) - - g = g.to(device) - - z = np.random.RandomState(5).randn(n_sample, 512).astype("float32") - - with torch.no_grad(): - img_pt, _ = g(8) - - utils.save_image( - img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1) - ) diff --git a/codes/scripts/stylegan2/convert_weights_rosinality.py b/codes/scripts/stylegan2/convert_weights_rosinality.py index a6106187..d0b505c7 100644 --- a/codes/scripts/stylegan2/convert_weights_rosinality.py +++ b/codes/scripts/stylegan2/convert_weights_rosinality.py @@ -123,7 +123,7 @@ def update(state_dict, new): def discriminator_fill_statedict(statedict, vars, size): log_size = int(math.log(size, 2)) - update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) + update(statedict, convert_conv(vars, f"D/{size}x{size}/FromRGB", "convs.0")) conv_i = 1 @@ -131,25 +131,25 @@ def discriminator_fill_statedict(statedict, vars, size): reso = 4 * 2 ** i update( statedict, - convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), + convert_conv(vars, f"D/{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), ) update( statedict, convert_conv( - vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 + vars, f"D/{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 ), ) update( statedict, convert_conv( - vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False + vars, f"D/{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False ), 
) conv_i += 1 - update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) - update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) - update(statedict, convert_dense(vars, f"Output", "final_linear.1")) + update(statedict, convert_conv(vars, f"D/4x4/Conv", "final_conv")) + update(statedict, convert_dense(vars, f"D/4x4/Dense0", "final_linear.0")) + update(statedict, convert_dense(vars, f"D/Output", "final_linear.1")) return statedict @@ -249,9 +249,14 @@ if __name__ == "__main__": g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) state_dict = g.state_dict() state_dict = fill_statedict(state_dict, gen_ema, size) - g.load_state_dict(state_dict, strict=True) + d = Discriminator(size, args.channel_multiplier) + dstate_dict = d.state_dict() + dstate_dict = discriminator_fill_statedict(dstate_dict, discriminator, size) + d.load_state_dict(dstate_dict, strict=True) + + latent_avg = torch.from_numpy(get_vars_direct(gen_ema, "G/dlatent_avg")) ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} @@ -269,14 +274,16 @@ if __name__ == "__main__": ckpt["d"] = d_state name = os.path.splitext(os.path.basename(args.path))[0] - torch.save(state_dict, name + ".pth") + torch.save(state_dict, f"{name}_gen.pth") + torch.save(dstate_dict, f"{name}_disc.pth") batch_size = {256: 16, 512: 9, 1024: 4} n_sample = batch_size.get(size, 25) g = g.to(device) + d = d.to(device) - z = np.random.RandomState(5).randn(n_sample, 512).astype("float32") + z = np.random.RandomState(1).randn(n_sample, 512).astype("float32") with torch.no_grad(): img_pt, _ = g( @@ -285,6 +292,8 @@ if __name__ == "__main__": truncation_latent=latent_avg.to(device), randomize_noise=False, ) + disc = d(img_pt) + print(disc) utils.save_image( img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1) diff --git a/codes/train.py b/codes/train.py index 515ea149..2fdfb4b3 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_vqvae3_stage1.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 3fffe510..e3bcb3f6 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -13,6 +13,8 @@ from trainer.steps import ConfigurableStep from trainer.experiments.experiments import get_experiment_for_name import torchvision.utils as utils +from utils.util import opt_get + logger = logging.getLogger('base') @@ -253,6 +255,8 @@ class ExtensibleTrainer(BaseModel): # Record visual outputs for usage in debugging and testing. 
         if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
+            denorm = opt_get(self.opt, ['logger', 'denormalize'], False)
+            denorm_range = tuple(opt_get(self.opt, ['logger', 'denormalize_range'], [])) or None  # default to [] so an unset option yields None instead of tuple(None) raising a TypeError
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
             for v in self.opt['logger']['visuals']:
                 if v not in state.keys():
@@ -264,12 +268,12 @@ class ExtensibleTrainer(BaseModel):
                         if rdbgv.shape[1] > 3:
                             rdbgv = rdbgv[:, :3, :, :]
                         os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                        utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
+                        utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)), normalize=denorm, range=denorm_range)
                 else:
                     if dbgv.shape[1] > 3:
                         dbgv = dbgv[:,:3,:,:]
                     os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                    utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
+                    utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)), normalize=denorm, range=denorm_range)
         # Some models have their own specific visual debug routines.
         for net_name, net in self.networks.items():
             if hasattr(net.module, "visual_dbg"):
diff --git a/codes/trainer/eval/fid.py b/codes/trainer/eval/fid.py
index 631f328e..6c662593 100644
--- a/codes/trainer/eval/fid.py
+++ b/codes/trainer/eval/fid.py
@@ -8,6 +8,9 @@ from pytorch_fid import fid_score
 
 
 # Evaluate that generates uniform noise to feed into a generator, then calculates a FID score on the results.
+from utils.util import opt_get
+
+
 class StyleTransferEvaluator(evaluator.Evaluator):
     def __init__(self, model, opt_eval, env):
         super().__init__(model, opt_eval, env)
@@ -16,13 +19,18 @@ class StyleTransferEvaluator(evaluator.Evaluator):
         self.im_sz = opt_eval['image_size']
         self.fid_real_samples = opt_eval['real_fid_path']
         self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
+        self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise')
+        self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512)  # Not needed if using 'imgnoise' input.
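The new 'stylenoise' mode in perform_eval() below feeds the generator a list holding a single (batch, latent_dim) Gaussian latent batch rather than a uniform noise image. A sketch of the two input shapes, assuming a batch size of 8, image size 256 and latent_dim 512:

    import torch

    img_batch = torch.rand(8, 3, 256, 256)    # 'imgnoise' (default): uniform [0, 1) image-shaped noise, as before
    style_batch = [torch.randn(8, 512)]       # 'stylenoise': list with one latent batch, for style-based generators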
     def perform_eval(self):
         fid_fake_path = osp.join(self.env['base_path'], "../../models", "fid", str(self.env["step"]))
         os.makedirs(fid_fake_path, exist_ok=True)
         counter = 0
         for i in range(self.batches_per_eval):
-            batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
+            if self.noise_type == 'imgnoise':
+                batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
+            elif self.noise_type == 'stylenoise':
+                batch = [torch.randn(self.batch_sz, self.latent_dim).to(self.env['device'])]
             gen = self.model(batch)
             if not isinstance(gen, list) and not isinstance(gen, tuple):
                 gen = [gen]
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index 28006fec..2acfca6b 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -6,6 +6,7 @@ from torch.cuda.amp import autocast
 
 from trainer.inject import Injector
 from trainer.losses import extract_params_from_state
+from utils.util import opt_get
 from utils.weight_scheduler import get_scheduler_for_opt
 
 
@@ -386,3 +387,19 @@ class RandomCropInjector(Injector):
 
     def forward(self, state):
         return {self.output: self.operator(state[self.input])}
+
+class Stylegan2NoiseInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.mix_prob = opt_get(opt, ['mix_probability'], .9)
+        self.latent_dim = opt_get(opt, ['latent_dim'], 512)
+
+    def make_noise(self, batch, latent_dim, n_noise, device):
+        return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
+
+    def forward(self, state):
+        i = state[self.input]
+        if self.mix_prob > 0 and random.random() < self.mix_prob:
+            return {self.output: self.make_noise(i.shape[0], self.latent_dim, 2, i.device)}
+        else:
+            return {self.output: self.make_noise(i.shape[0], self.latent_dim, 1, i.device)}
\ No newline at end of file
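For reference, Stylegan2NoiseInjector emits a tuple of either one or two latent batches: with probability mix_probability (default 0.9) it produces two codes so the generator can perform style mixing. A standalone sketch of the output, assuming a batch of 4 and latent_dim 512:

    import torch

    mixed = torch.randn(2, 4, 512).unbind(0)    # tuple of two (4, 512) latents, used for a style-mixing step
    single = torch.randn(1, 4, 512).unbind(0)   # tuple of one (4, 512) latent

The rosinality Generator registered above takes its styles argument as a sequence of latent codes, so either tuple can presumably be passed through to it unchanged.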