forked from mrq/DL-Art-School

commit 06d1c62c5a (parent c18adbd606)

iGPT support!
Sweeeeet
.gitignore (vendored) | 1 +
@@ -5,6 +5,7 @@ datasets/*
 options/*
 codes/*.txt
 codes/wandb/*
 data/*
 .vscode
 
+*.html
@@ -45,6 +45,8 @@ def create_dataset(dataset_opt):
         from data.stylegan2_dataset import Stylegan2Dataset as D
     elif mode == 'imagefolder':
         from data.image_folder_dataset import ImageFolderDataset as D
+    elif mode == 'torch_dataset':
+        from data.torch_dataset import TorchDataset as D
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
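
The new 'torch_dataset' branch is driven entirely by the dataset_opt dict. A minimal sketch of an opt that reaches it, assuming only the keys consumed by TorchDataset below (the values here are illustrative, not from the commit):

    # Hypothetical options dict; 'mode' picks the branch above, the rest feeds TorchDataset.
    dataset_opt = {
        'mode': 'torch_dataset',
        'dataset': 'cifar10',     # one of mnist / fmnist / cifar10
        'datapath': '/tmp/data',
        'flip': True,
        'crop_sz': None,          # set to an int (with 'padding') to enable RandomCrop
    }
    dataset = create_dataset(dataset_opt)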
codes/data/torch_dataset.py (new file) | 31 ++
@@ -0,0 +1,31 @@
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as T
+from torchvision import datasets
+
+# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
+class TorchDataset(Dataset):
+    def __init__(self, opt):
+        DATASET_MAP = {
+            "mnist": datasets.MNIST,
+            "fmnist": datasets.FashionMNIST,
+            "cifar10": datasets.CIFAR10,
+        }
+        transforms = []
+        if opt['flip']:
+            transforms.append(T.RandomHorizontalFlip())
+        if opt['crop_sz']:
+            transforms.append(T.RandomCrop(opt['crop_sz'], padding=opt['padding'], padding_mode="reflect"))
+        transforms.append(T.ToTensor())
+        transforms = T.Compose(transforms)
+        is_for_training = not opt['test'] if 'test' in opt.keys() else True  # use the train split unless opt marks this as a test set
+        self.dataset = DATASET_MAP[opt['dataset']](opt['datapath'], train=is_for_training, download=True, transform=transforms)
+        self.len = opt['fixed_len'] if 'fixed_len' in opt.keys() else len(self.dataset)
+
+    def __getitem__(self, item):
+        underlying_item = self.dataset[item][0]
+        return {'LQ': underlying_item, 'GT': underlying_item,
+                'LQ_path': str(item), 'GT_path': str(item)}
+
+    def __len__(self):
+        return self.len
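
As a quick sanity check, the wrapper can also be constructed directly; a sketch assuming torchvision can download MNIST to the given path:

    from data.torch_dataset import TorchDataset

    opt = {'dataset': 'mnist', 'datapath': '/tmp/mnist', 'flip': False, 'crop_sz': None}
    ds = TorchDataset(opt)
    sample = ds[0]
    # ExtensibleTrainer-style dict: identical LQ/GT tensors plus string paths.
    print(sample['LQ'].shape, sample['GT_path'])  # torch.Size([1, 28, 28]) '0'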
codes/models/archs/transformers/igpt/gpt2.py (new file) | 150 ++
@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import torchvision
+from models.steps.injectors import Injector
+from utils.util import checkpoint
+
+
+def create_injector(opt, env):
+    type = opt['type']
+    if type == 'igpt_resolve':
+        return ResolveInjector(opt, env)
+    return None
+
+
+class ResolveInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.gen = opt['generator']
+        self.samples = opt['num_samples']
+        self.temperature = opt['temperature']
+
+    def forward(self, state):
+        gen = self.env['generators'][self.opt['generator']].module
+        img = state[self.input]
+        b, c, h, w = img.shape
+        qimg = gen.quantize(img)
+        s, b = qimg.shape
+        qimg = qimg[:s//2, :]
+        output = qimg.repeat(1, self.samples)
+
+        pad = torch.zeros(1, self.samples, dtype=torch.long).cuda()  # to pad prev output
+        with torch.no_grad():
+            for _ in range(s//2):
+                logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True)
+                logits = logits[-1, :, :] / self.temperature
+                probs = F.softmax(logits, dim=-1)
+                pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
+                output = torch.cat((output, pred), dim=0)
+        output = gen.unquantize(output.reshape(h, w, -1))
+        return {self.output: output.permute(2,3,0,1).contiguous()}
+
+
+class Block(nn.Module):
+    def __init__(self, embed_dim, num_heads):
+        super(Block, self).__init__()
+        self.ln_1 = nn.LayerNorm(embed_dim)
+        self.ln_2 = nn.LayerNorm(embed_dim)
+        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
+        self.mlp = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim * 4),
+            nn.GELU(),
+            nn.Linear(embed_dim * 4, embed_dim),
+        )
+
+    def forward(self, x):
+        attn_mask = torch.full(
+            (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
+        )
+        attn_mask = torch.triu(attn_mask, diagonal=1)
+
+        x = self.ln_1(x)
+        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
+        x = x + a
+        m = self.mlp(self.ln_2(x))
+        x = x + m
+        return x
+
+
+class iGPT2(nn.Module):
+    def __init__(
+        self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
+    ):
+        super().__init__()
+
+        self.centroids = nn.Parameter(
+            torch.from_numpy(np.load(centroids_file)), requires_grad=False
+        )
+        self.embed_dim = embed_dim
+
+        # start of sequence token
+        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
+        nn.init.normal_(self.sos)
+
+        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
+        self.position_embeddings = nn.Embedding(num_positions, embed_dim)
+
+        self.layers = nn.ModuleList()
+        for _ in range(num_layers):
+            self.layers.append(Block(embed_dim, num_heads))
+
+        self.ln_f = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
+        self.clf_head = nn.Linear(embed_dim, 10)  # Fixed num_classes; this head is unused since the network is not employed as a classifier here.
+
+    def squared_euclidean_distance(self, a, b):
+        b = torch.transpose(b, 0, 1)
+        a2 = torch.sum(torch.square(a), dim=1, keepdims=True)
+        b2 = torch.sum(torch.square(b), dim=0, keepdims=True)
+        ab = torch.matmul(a, b)
+        d = a2 - 2 * ab + b2
+        return d
+
+    def quantize(self, x):
+        b, c, h, w = x.shape
+        # [B, C, H, W] => [B, H, W, C]
+        x = x.permute(0, 2, 3, 1).contiguous()
+        x = x.view(-1, c)  # flatten to pixels
+        d = self.squared_euclidean_distance(x, self.centroids)
+        x = torch.argmin(d, 1)
+        x = x.view(b, h, w)
+
+        # Reshape output to [seq_len, batch].
+        x = x.view(x.shape[0], -1)  # flatten images into sequences
+        x = x.transpose(0, 1).contiguous()  # to shape [seq_len, batch]
+        return x
+
+    def unquantize(self, x):
+        return self.centroids[x]
+
+    def forward(self, x, already_quantized=False):
+        """
+        Expect input as shape [b, c, h, w]
+        """
+
+        if not already_quantized:
+            x = self.quantize(x)
+        length, batch = x.shape
+
+        h = self.token_embeddings(x)
+
+        # prepend sos token
+        sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
+        h = torch.cat([sos, h[:-1, :, :]], axis=0)
+
+        # add positional embeddings
+        positions = torch.arange(length, device=x.device).unsqueeze(-1)
+        h = h + self.position_embeddings(positions).expand_as(h)
+
+        # transformer
+        for layer in self.layers:
+            h = checkpoint(layer, h)
+
+        h = self.ln_f(h)
+
+        logits = self.head(h)
+
+        return logits, x
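
To see the plumbing end to end: quantize() maps each pixel to its nearest centroid, so the centroids file must hold a [num_vocab, 3] float array. A small smoke-test sketch under that assumption (all sizes hypothetical):

    import numpy as np
    import torch
    from models.archs.transformers.igpt.gpt2 import iGPT2

    np.save('/tmp/centroids.npy', np.random.rand(16, 3).astype(np.float32))  # stand-in centroids
    model = iGPT2(embed_dim=32, num_heads=2, num_layers=2,
                  num_positions=8 * 8, num_vocab=16,
                  centroids_file='/tmp/centroids.npy')
    img = torch.rand(4, 3, 8, 8)       # [b, c, h, w]
    logits, tokens = model(img)
    print(logits.shape, tokens.shape)  # [64, 4, 16] and [64, 4], i.e. [seq_len, batch, ...]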
@@ -16,7 +16,6 @@ import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.feature_arch as feature_arch
-import models.archs.panet.panet as panet
 import models.archs.rcan as rcan
 from models.archs import srg2_classic
 from models.archs.biggan.biggan_discriminator import BigGanDiscriminator

@@ -69,12 +68,6 @@ def define_G(opt, opt_net, scale=None):
         opt_net['n_colors'] = 3
         args_obj = munchify(opt_net)
         netG = rcan.RCAN(args_obj)
-    elif which_model == 'panet':
-        #args: n_resblocks, res_scale, scale, n_feats
-        opt_net['rgb_range'] = 255
-        opt_net['n_colors'] = 3
-        args_obj = munchify(opt_net)
-        netG = panet.PANET(args_obj)
     elif which_model == "ConfigurableSwitchedResidualGenerator2":
         netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
                                                                        switch_reductions=opt_net['switch_reductions'],

@@ -158,10 +151,9 @@ def define_G(opt, opt_net, scale=None):
         netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                        nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
                        initial_conv_stride=opt_net['initial_stride'])
-    elif which_model == 'mdcn':
-        from models.archs.mdcn.mdcn import MDCN
-        args = munchify({'scale': opt_net['scale'], 'n_colors': 3, 'rgb_range': 1.0})
-        netG = MDCN(args)
+    elif which_model == 'igpt2':
+        from models.archs.transformers.igpt.gpt2 import iGPT2
+        netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'], opt_net['num_pixels'] ** 2, opt_net['num_vocab'], centroids_file=opt_net['centroids_file'])
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
     return netG
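
Given the call above, the opt_net for the new branch needs embed_dim, num_heads, num_layers, num_pixels, num_vocab and centroids_file; num_positions is derived as num_pixels ** 2. A hedged sketch (the 'which_model_G' key is assumed from the surrounding branches; all sizes are illustrative):

    opt_net = {
        'which_model_G': 'igpt2',
        'embed_dim': 512, 'num_heads': 8, 'num_layers': 12,
        'num_pixels': 32,             # image side length; yields 32 ** 2 = 1024 positions
        'num_vocab': 512,             # must match the centroid count
        'centroids_file': 'pretrained/centroids.npy',
    }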

@@ -20,6 +20,9 @@ def create_injector(opt_inject, env):
     elif 'stereoscopic_' in type:
         from models.steps.stereoscopic import create_stereoscopic_injector
         return create_stereoscopic_injector(opt_inject, env)
+    elif 'igpt' in type:
+        from models.archs.transformers.igpt import gpt2
+        return gpt2.create_injector(opt_inject, env)
     elif type == 'generator':
         return ImageGeneratorInjector(opt_inject, env)
     elif type == 'discriminator':
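
The dispatch above matches any injector type containing 'igpt'; gpt2.create_injector then resolves 'igpt_resolve' to ResolveInjector. A sketch of a plausible injector config (the 'in'/'out' keys are assumed from the Injector base class; the state names are illustrative):

    injector_opt = {
        'type': 'igpt_resolve',
        'generator': 'generator',   # key into env['generators']
        'in': 'hq', 'out': 'sample',
        'num_samples': 8,           # completions drawn per half-image
        'temperature': 1.0,         # softmax temperature for sampling
    }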
@@ -17,6 +17,8 @@ def create_loss(opt_loss, env):
     elif 'stylegan2_' in type:
         from models.archs.stylegan import create_stylegan2_loss
         return create_stylegan2_loss(opt_loss, env)
+    elif type == 'crossentropy':
+        return CrossEntropy(opt_loss, env)
     elif type == 'pix':
         return PixLoss(opt_loss, env)
     elif type == 'direct':

@@ -89,6 +91,18 @@ def get_basic_criterion_for_name(name, device):
         raise NotImplementedError


+class CrossEntropy(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.opt = opt
+        self.ce = nn.CrossEntropyLoss()
+
+    def forward(self, _, state):
+        labels = state[self.opt['labels']]
+        logits = state[self.opt['logits']]
+        return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+
 class PixLoss(ConfigurableLoss):
     def __init__(self, opt, env):
         super(PixLoss, self).__init__(opt, env)
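
CrossEntropy pulls both tensors out of the trainer state by name, flattening logits to [N, num_vocab] and labels to [N]; with iGPT2 this pairs naturally with the (logits, targets) tuple its forward() returns. A sketch of the loss options (the state key names here are hypothetical):

    loss_opt = {
        'type': 'crossentropy',
        'logits': 'gen_logits',   # state key holding [seq_len, batch, num_vocab]
        'labels': 'gen_targets',  # state key holding [seq_len, batch] token ids
    }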