iGPT support!

Sweeeeet
This commit is contained in:
James Betker 2020-12-03 15:32:21 -07:00
parent c18adbd606
commit 06d1c62c5a
7 changed files with 204 additions and 11 deletions

1
.gitignore vendored
View File

@ -5,6 +5,7 @@ datasets/*
options/*
codes/*.txt
codes/wandb/*
data/*
.vscode
*.html

View File

@ -45,6 +45,8 @@ def create_dataset(dataset_opt):
from data.stylegan2_dataset import Stylegan2Dataset as D
elif mode == 'imagefolder':
from data.image_folder_dataset import ImageFolderDataset as D
elif mode == 'torch_dataset':
from data.torch_dataset import TorchDataset as D
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)

View File

@ -0,0 +1,31 @@
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from torchvision import datasets
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
class TorchDataset(Dataset):
def __init__(self, opt):
DATASET_MAP = {
"mnist": datasets.MNIST,
"fmnist": datasets.FashionMNIST,
"cifar10": datasets.CIFAR10,
}
transforms = []
if opt['flip']:
transforms.append(T.RandomHorizontalFlip())
if opt['crop_sz']:
transforms.append(T.RandomCrop(opt['crop_sz'], padding=opt['padding'], padding_mode="reflect"))
transforms.append(T.ToTensor())
transforms = T.Compose(transforms)
is_for_training = opt['test'] if 'test' in opt.keys() else True
self.dataset = DATASET_MAP[opt['dataset']](opt['datapath'], train=is_for_training, download=True, transform=transforms)
self.len = opt['fixed_len'] if 'fixed_len' in opt.keys() else len(self.dataset)
def __getitem__(self, item):
underlying_item = self.dataset[item][0]
return {'LQ': underlying_item, 'GT': underlying_item,
'LQ_path': str(item), 'GT_path': str(item)}
def __len__(self):
return self.len

View File

@ -0,0 +1,150 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
from models.steps.injectors import Injector
from utils.util import checkpoint
def create_injector(opt, env):
type = opt['type']
if type == 'igpt_resolve':
return ResolveInjector(opt, env)
return None
class ResolveInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.gen = opt['generator']
self.samples = opt['num_samples']
self.temperature = opt['temperature']
def forward(self, state):
gen = self.env['generators'][self.opt['generator']].module
img = state[self.input]
b, c, h, w = img.shape
qimg = gen.quantize(img)
s, b = qimg.shape
qimg = qimg[:s//2, :]
output = qimg.repeat(1, self.samples)
pad = torch.zeros(1, self.samples, dtype=torch.long).cuda() # to pad prev output
with torch.no_grad():
for _ in range(s//2):
logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True)
logits = logits[-1, :, :] / self.temperature
probs = F.softmax(logits, dim=-1)
pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
output = torch.cat((output, pred), dim=0)
output = gen.unquantize(output.reshape(h, w, -1))
return {self.output: output.permute(2,3,0,1).contiguous()}
class Block(nn.Module):
def __init__(self, embed_dim, num_heads):
super(Block, self).__init__()
self.ln_1 = nn.LayerNorm(embed_dim)
self.ln_2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.GELU(),
nn.Linear(embed_dim * 4, embed_dim),
)
def forward(self, x):
attn_mask = torch.full(
(len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
)
attn_mask = torch.triu(attn_mask, diagonal=1)
x = self.ln_1(x)
a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
return x
class iGPT2(nn.Module):
def __init__(
self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
):
super().__init__()
self.centroids = nn.Parameter(
torch.from_numpy(np.load(centroids_file)), requires_grad=False
)
self.embed_dim = embed_dim
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
self.position_embeddings = nn.Embedding(num_positions, embed_dim)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(Block(embed_dim, num_heads))
self.ln_f = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_vocab, bias=False)
self.clf_head = nn.Linear(embed_dim, 10) # Fixed num_classes, this is not a classifier.
def squared_euclidean_distance(self, a, b):
b = torch.transpose(b, 0, 1)
a2 = torch.sum(torch.square(a), dim=1, keepdims=True)
b2 = torch.sum(torch.square(b), dim=0, keepdims=True)
ab = torch.matmul(a, b)
d = a2 - 2 * ab + b2
return d
def quantize(self, x):
b, c, h, w = x.shape
# [B, C, H, W] => [B, H, W, C]
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(-1, c) # flatten to pixels
d = self.squared_euclidean_distance(x, self.centroids)
x = torch.argmin(d, 1)
x = x.view(b, h, w)
# Reshape output to [seq_len, batch].
x = x.view(x.shape[0], -1) # flatten images into sequences
x = x.transpose(0, 1).contiguous() # to shape [seq len, batch]
return x
def unquantize(self, x):
return self.centroids[x]
def forward(self, x, already_quantized=False):
"""
Expect input as shape [b, c, h, w]
"""
if not already_quantized:
x = self.quantize(x)
length, batch = x.shape
h = self.token_embeddings(x)
# prepend sos token
sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
h = torch.cat([sos, h[:-1, :, :]], axis=0)
# add positional embeddings
positions = torch.arange(length, device=x.device).unsqueeze(-1)
h = h + self.position_embeddings(positions).expand_as(h)
# transformer
for layer in self.layers:
h = checkpoint(layer, h)
h = self.ln_f(h)
logits = self.head(h)
return logits, x

View File

@ -16,7 +16,6 @@ import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.feature_arch as feature_arch
import models.archs.panet.panet as panet
import models.archs.rcan as rcan
from models.archs import srg2_classic
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
@ -69,12 +68,6 @@ def define_G(opt, opt_net, scale=None):
opt_net['n_colors'] = 3
args_obj = munchify(opt_net)
netG = rcan.RCAN(args_obj)
elif which_model == 'panet':
#args: n_resblocks, res_scale, scale, n_feats
opt_net['rgb_range'] = 255
opt_net['n_colors'] = 3
args_obj = munchify(opt_net)
netG = panet.PANET(args_obj)
elif which_model == "ConfigurableSwitchedResidualGenerator2":
netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
@ -158,10 +151,9 @@ def define_G(opt, opt_net, scale=None):
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
initial_conv_stride=opt_net['initial_stride'])
elif which_model == 'mdcn':
from models.archs.mdcn.mdcn import MDCN
args = munchify({'scale': opt_net['scale'], 'n_colors': 3, 'rgb_range': 1.0})
netG = MDCN(args)
elif which_model == 'igpt2':
from models.archs.transformers.igpt.gpt2 import iGPT2
netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG

View File

@ -20,6 +20,9 @@ def create_injector(opt_inject, env):
elif 'stereoscopic_' in type:
from models.steps.stereoscopic import create_stereoscopic_injector
return create_stereoscopic_injector(opt_inject, env)
elif 'igpt' in type:
from models.archs.transformers.igpt import gpt2
return gpt2.create_injector(opt_inject, env)
elif type == 'generator':
return ImageGeneratorInjector(opt_inject, env)
elif type == 'discriminator':

View File

@ -17,6 +17,8 @@ def create_loss(opt_loss, env):
elif 'stylegan2_' in type:
from models.archs.stylegan import create_stylegan2_loss
return create_stylegan2_loss(opt_loss, env)
elif type == 'crossentropy':
return CrossEntropy(opt_loss, env)
elif type == 'pix':
return PixLoss(opt_loss, env)
elif type == 'direct':
@ -89,6 +91,18 @@ def get_basic_criterion_for_name(name, device):
raise NotImplementedError
class CrossEntropy(ConfigurableLoss):
def __init__(self, opt, env):
super().__init__(opt, env)
self.opt = opt
self.ce = nn.CrossEntropyLoss()
def forward(self, _, state):
labels = state[self.opt['labels']]
logits = state[self.opt['logits']]
return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))
class PixLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(PixLoss, self).__init__(opt, env)