From 06d1c62c5a18e497309ec5f8064e9abaffbc5592 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 3 Dec 2020 15:32:21 -0700
Subject: [PATCH] iGPT support! Sweeeeet

---
 .gitignore                                   |   1 +
 codes/data/__init__.py                       |   2 +
 codes/data/torch_dataset.py                  |  31 ++++
 codes/models/archs/transformers/igpt/gpt2.py | 150 +++++++++++++++++++
 codes/models/networks.py                     |  14 +-
 codes/models/steps/injectors.py              |   3 +
 codes/models/steps/losses.py                 |  14 ++
 7 files changed, 204 insertions(+), 11 deletions(-)
 create mode 100644 codes/data/torch_dataset.py
 create mode 100644 codes/models/archs/transformers/igpt/gpt2.py

diff --git a/.gitignore b/.gitignore
index 2d8e8277..8f87fb3c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@
 datasets/*
 options/*
 codes/*.txt
 codes/wandb/*
+data/*
 .vscode
 *.html
diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 753ae987..8af03960 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -45,6 +45,8 @@ def create_dataset(dataset_opt):
         from data.stylegan2_dataset import Stylegan2Dataset as D
     elif mode == 'imagefolder':
         from data.image_folder_dataset import ImageFolderDataset as D
+    elif mode == 'torch_dataset':
+        from data.torch_dataset import TorchDataset as D
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
diff --git a/codes/data/torch_dataset.py b/codes/data/torch_dataset.py
new file mode 100644
index 00000000..58ed34de
--- /dev/null
+++ b/codes/data/torch_dataset.py
@@ -0,0 +1,31 @@
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as T
+from torchvision import datasets
+
+# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
+class TorchDataset(Dataset):
+    def __init__(self, opt):
+        DATASET_MAP = {
+            "mnist": datasets.MNIST,
+            "fmnist": datasets.FashionMNIST,
+            "cifar10": datasets.CIFAR10,
+        }
+        transforms = []
+        if opt['flip']:
+            transforms.append(T.RandomHorizontalFlip())
+        if opt['crop_sz']:
+            transforms.append(T.RandomCrop(opt['crop_sz'], padding=opt['padding'], padding_mode="reflect"))
+        transforms.append(T.ToTensor())
+        transforms = T.Compose(transforms)
+        is_for_training = not opt['test'] if 'test' in opt.keys() else True  # 'test': True selects the held-out split
+        self.dataset = DATASET_MAP[opt['dataset']](opt['datapath'], train=is_for_training, download=True, transform=transforms)
+        self.len = opt['fixed_len'] if 'fixed_len' in opt.keys() else len(self.dataset)
+
+    def __getitem__(self, item):
+        underlying_item = self.dataset[item][0]
+        return {'LQ': underlying_item, 'GT': underlying_item,
+                'LQ_path': str(item), 'GT_path': str(item)}
+
+    def __len__(self):
+        return self.len
diff --git a/codes/models/archs/transformers/igpt/gpt2.py b/codes/models/archs/transformers/igpt/gpt2.py
new file mode 100644
index 00000000..39388b11
--- /dev/null
+++ b/codes/models/archs/transformers/igpt/gpt2.py
@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import torchvision
+from models.steps.injectors import Injector
+from utils.util import checkpoint
+
+
+def create_injector(opt, env):
+    type = opt['type']
+    if type == 'igpt_resolve':
+        return ResolveInjector(opt, env)
+    return None
+
+
+class ResolveInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.gen = opt['generator']
+        self.samples = opt['num_samples']
+        self.temperature = opt['temperature']
+
+    def forward(self, state):
+        gen = self.env['generators'][self.opt['generator']].module
+        img = state[self.input]
+        b, c, h, w = img.shape
+        qimg = gen.quantize(img)
+        s, b = qimg.shape
+        qimg = qimg[:s//2, :]  # keep the top half of the image as conditioning; assumes batch size 1
+        output = qimg.repeat(1, self.samples)  # replicate it so num_samples completions are drawn at once
+
+        pad = torch.zeros(1, self.samples, dtype=torch.long).cuda()  # to pad prev output
+        with torch.no_grad():
+            for _ in range(s//2):  # autoregressively sample the bottom half, one quantized pixel at a time
+                logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True)
+                logits = logits[-1, :, :] / self.temperature
+                probs = F.softmax(logits, dim=-1)
+                pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
+                output = torch.cat((output, pred), dim=0)
+        output = gen.unquantize(output.reshape(h, w, -1))
+        return {self.output: output.permute(2, 3, 0, 1).contiguous()}
+
+
+class Block(nn.Module):
+    def __init__(self, embed_dim, num_heads):
+        super(Block, self).__init__()
+        self.ln_1 = nn.LayerNorm(embed_dim)
+        self.ln_2 = nn.LayerNorm(embed_dim)
+        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
+        self.mlp = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim * 4),
+            nn.GELU(),
+            nn.Linear(embed_dim * 4, embed_dim),
+        )
+
+    def forward(self, x):
+        attn_mask = torch.full(  # causal mask: position i may only attend to positions <= i
+            (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
+        )
+        attn_mask = torch.triu(attn_mask, diagonal=1)
+
+        x = self.ln_1(x)
+        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
+        x = x + a
+        m = self.mlp(self.ln_2(x))
+        x = x + m
+        return x
+
+
+class iGPT2(nn.Module):
+    def __init__(
+        self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
+    ):
+        super().__init__()
+
+        self.centroids = nn.Parameter(
+            torch.from_numpy(np.load(centroids_file)), requires_grad=False
+        )
+        self.embed_dim = embed_dim
+
+        # start of sequence token
+        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
+        nn.init.normal_(self.sos)
+
+        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
+        self.position_embeddings = nn.Embedding(num_positions, embed_dim)
+
+        self.layers = nn.ModuleList()
+        for _ in range(num_layers):
+            self.layers.append(Block(embed_dim, num_heads))
+
+        self.ln_f = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
+        self.clf_head = nn.Linear(embed_dim, 10)  # num_classes fixed at 10; unused here since this network is not used as a classifier.
+
+    def squared_euclidean_distance(self, a, b):
+        b = torch.transpose(b, 0, 1)
+        a2 = torch.sum(torch.square(a), dim=1, keepdim=True)
+        b2 = torch.sum(torch.square(b), dim=0, keepdim=True)
+        ab = torch.matmul(a, b)
+        d = a2 - 2 * ab + b2
+        return d
+
+    def quantize(self, x):
+        b, c, h, w = x.shape
+        # [B, C, H, W] => [B, H, W, C]
+        x = x.permute(0, 2, 3, 1).contiguous()
+        x = x.view(-1, c)  # flatten to pixels
+        d = self.squared_euclidean_distance(x, self.centroids)
+        x = torch.argmin(d, 1)
+        x = x.view(b, h, w)
+
+        # Reshape output to [seq_len, batch].
+        x = x.view(x.shape[0], -1)  # flatten images into sequences
+        x = x.transpose(0, 1).contiguous()  # to shape [seq_len, batch]
+        return x
+
+    def unquantize(self, x):
+        return self.centroids[x]
+
+    def forward(self, x, already_quantized=False):
+        """
+        Expects input of shape [b, c, h, w].
+        """
+
+        if not already_quantized:
+            x = self.quantize(x)
+        length, batch = x.shape
+
+        h = self.token_embeddings(x)
+
+        # prepend sos token
+        sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
+        h = torch.cat([sos, h[:-1, :, :]], dim=0)
+
+        # add positional embeddings
+        positions = torch.arange(length, device=x.device).unsqueeze(-1)
+        h = h + self.position_embeddings(positions).expand_as(h)
+
+        # transformer
+        for layer in self.layers:
+            h = checkpoint(layer, h)
+
+        h = self.ln_f(h)
+
+        logits = self.head(h)
+
+        return logits, x
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 8b4b462f..3205dd54 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -16,7 +16,6 @@ import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.feature_arch as feature_arch
-import models.archs.panet.panet as panet
 import models.archs.rcan as rcan
 from models.archs import srg2_classic
 from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
@@ -69,12 +68,6 @@ def define_G(opt, opt_net, scale=None):
         opt_net['n_colors'] = 3
         args_obj = munchify(opt_net)
         netG = rcan.RCAN(args_obj)
-    elif which_model == 'panet':
-        #args: n_resblocks, res_scale, scale, n_feats
-        opt_net['rgb_range'] = 255
-        opt_net['n_colors'] = 3
-        args_obj = munchify(opt_net)
-        netG = panet.PANET(args_obj)
     elif which_model == "ConfigurableSwitchedResidualGenerator2":
         netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
                                                                        switch_reductions=opt_net['switch_reductions'],
@@ -158,10 +151,9 @@ def define_G(opt, opt_net, scale=None):
         netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                        nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
                        initial_conv_stride=opt_net['initial_stride'])
-    elif which_model == 'mdcn':
-        from models.archs.mdcn.mdcn import MDCN
-        args = munchify({'scale': opt_net['scale'], 'n_colors': 3, 'rgb_range': 1.0})
-        netG = MDCN(args)
+    elif which_model == 'igpt2':
+        from models.archs.transformers.igpt.gpt2 import iGPT2
+        netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
     return netG
diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index cc9f6e54..a447fa49 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -20,6 +20,9 @@ def create_injector(opt_inject, env):
     elif 'stereoscopic_' in type:
         from models.steps.stereoscopic import create_stereoscopic_injector
         return create_stereoscopic_injector(opt_inject, env)
+    elif 'igpt' in type:
+        from models.archs.transformers.igpt import gpt2
+        return gpt2.create_injector(opt_inject, env)
     elif type == 'generator':
         return ImageGeneratorInjector(opt_inject, env)
     elif type == 'discriminator':
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index 02908e3e..6c42a205 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -17,6 +17,8 @@ def create_loss(opt_loss, env):
     elif 'stylegan2_' in type:
         from models.archs.stylegan import create_stylegan2_loss
         return create_stylegan2_loss(opt_loss, env)
+    elif type == 'crossentropy':
+        return CrossEntropy(opt_loss, env)
     elif type == 'pix':
         return PixLoss(opt_loss, env)
     elif type == 'direct':
@@ -89,6 +91,18 @@ def get_basic_criterion_for_name(name, device):
     raise NotImplementedError
 
 
+class CrossEntropy(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.opt = opt
+        self.ce = nn.CrossEntropyLoss()
+
+    def forward(self, _, state):
+        labels = state[self.opt['labels']]
+        logits = state[self.opt['logits']]
+        return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+
 class PixLoss(ConfigurableLoss):
     def __init__(self, opt, env):
         super(PixLoss, self).__init__(opt, env)
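
Usage sketch: the snippet below shows how the pieces added here fit together. iGPT2.forward() quantizes [b, c, h, w] images against the centroid palette and returns per-position logits plus the quantized targets, which is exactly the pair the new 'crossentropy' loss consumes. Everything except the hyperparameters and the centroids file comes from this patch; the randomly generated palette and its path are illustrative stand-ins (real iGPT centroids come from k-means over training pixels), and the sketch presumes utils.util.checkpoint falls back to a plain module call when no training options are loaded.

import numpy as np
import torch
import torch.nn as nn
from models.archs.transformers.igpt.gpt2 import iGPT2

# Stand-in [num_vocab, C] palette; C must match the input channel count.
np.save('centroids.npy', np.random.rand(512, 3).astype(np.float32))

model = iGPT2(embed_dim=256, num_heads=8, num_layers=8,
              num_positions=32 * 32,  # num_pixels ** 2, as wired up in define_G
              num_vocab=512, centroids_file='centroids.npy')

img = torch.rand(4, 3, 32, 32)  # [b, c, h, w], as TorchDataset emits
logits, targets = model(img)    # logits: [seq_len, b, num_vocab]; targets: [seq_len, b] quantized pixels

# Same objective the new CrossEntropy loss applies to the state keys named by opt['logits']/opt['labels']:
loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), targets.view(-1))
loss.backward()

At eval time the 'igpt_resolve' injector reuses the same forward pass with already_quantized=True, feeding sampled tokens back in to complete the bottom half of each image.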