1
0
forked from mrq/DL-Art-School
DL-Art-School/codes/models/transformers/igpt/gpt2.py

150 lines
4.7 KiB
Python
Raw Normal View History

2020-12-03 22:32:21 +00:00
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
2020-12-18 16:18:34 +00:00
from trainer.injectors import Injector
2020-12-03 22:32:21 +00:00
from utils.util import checkpoint
def create_injector(opt, env):
    """Build the injector named by ``opt['type']``.

    Args:
        opt: Injector configuration mapping; must contain a ``'type'`` key.
        env: Environment object forwarded unchanged to the injector.

    Returns:
        A ``ResolveInjector`` when ``opt['type'] == 'igpt_resolve'``,
        otherwise ``None`` (this module does not provide the requested type).
    """
    # Named `injector_type` rather than `type` so the builtin is not shadowed.
    injector_type = opt['type']
    if injector_type == 'igpt_resolve':
        return ResolveInjector(opt, env)
    return None
class ResolveInjector(Injector):
    """Complete the second half of an image by sampling from an iGPT generator.

    The input image is quantized to a token sequence, the first half of that
    sequence is kept as a prompt, and the remaining tokens are sampled one at
    a time from the generator's predicted distribution.

    Config keys read from ``opt``:
        generator: key of the generator inside ``env['generators']``.
        num_samples: number of independent completions drawn per input image.
        temperature: divisor applied to the logits before the softmax.

    NOTE(review): ``self.env``, ``self.input`` and ``self.output`` are assumed
    to be set by the ``Injector`` base class, which is not visible here.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.gen = opt['generator']
        self.samples = opt['num_samples']
        self.temperature = opt['temperature']

    def forward(self, state):
        """Sample completions for ``state[self.input]``.

        Args:
            state: mapping holding the input image batch, indexed as
                [batch, channels, height, width], under ``self.input``.

        Returns:
            ``{self.output: images}`` where ``images`` is indexed as
            [batch * num_samples, channels, height, width]. Completions are
            laid out sample-major (all inputs for sample 0, then sample 1, ...),
            which is the order produced by ``Tensor.repeat`` along the batch axis.
        """
        gen = self.env['generators'][self.gen]
        # Unwrap a (Distributed)DataParallel-style wrapper when present; a bare
        # module is used as-is.
        gen = getattr(gen, 'module', gen)

        img = state[self.input]
        _, _, h, w = img.shape

        qimg = gen.quantize(img)  # [seq_len, batch]
        seq_len = qimg.shape[0]
        prefix_len = seq_len // 2

        # Keep the first half as the prompt and tile it once per sample.
        output = qimg[:prefix_len, :].repeat(1, self.samples)

        # Placeholder row appended before each forward pass. The generator
        # shifts its input right by one (SOS is prepended, last token dropped),
        # so the logits at the final position predict the next token. The pad
        # must match the tiled width and live on the same device as `output`.
        pad = torch.zeros(
            1, output.shape[1], dtype=torch.long, device=output.device
        )

        with torch.no_grad():
            # `seq_len - prefix_len` (not `seq_len // 2`) so that odd-length
            # sequences are completed fully and the reshape below succeeds.
            for _ in range(seq_len - prefix_len):
                logits, _ = gen(
                    torch.cat((output, pad), dim=0), already_quantized=True
                )
                logits = logits[-1, :, :] / self.temperature
                probs = F.softmax(logits, dim=-1)
                # multinomial yields [N, 1]; transpose to a [1, N] row.
                pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
                output = torch.cat((output, pred), dim=0)

        # [seq_len, N] -> [h, w, N] -> centroid lookup -> [h, w, N, C]
        output = gen.unquantize(output.reshape(h, w, -1))
        return {self.output: output.permute(2, 3, 0, 1).contiguous()}
class Block(nn.Module):
    """One transformer layer: causal self-attention followed by a GELU MLP.

    The MLP expands to four times ``embed_dim`` and projects back. Inputs are
    indexed sequence-first ([seq, batch, embed]), which is the default layout
    of ``nn.MultiheadAttention``.

    Note that the attention residual is added to the *layer-normed* input, not
    to the raw input. This is kept as-is because changing it would alter the
    behaviour of already-trained weights.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        hidden_dim = embed_dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        """Apply the layer to ``x`` and return a tensor of the same shape."""
        seq_len = x.shape[0]
        # Additive causal mask: -inf strictly above the diagonal (future
        # positions), zero elsewhere. Same dtype and device as the input.
        causal_mask = x.new_full((seq_len, seq_len), float("-inf")).triu(diagonal=1)

        normed = self.ln_1(x)
        attended, _ = self.attn(
            normed, normed, normed, attn_mask=causal_mask, need_weights=False
        )
        hidden = normed + attended
        return hidden + self.mlp(self.ln_2(hidden))
class iGPT2(nn.Module):
    """GPT-2 style autoregressive transformer over colour-quantized pixels.

    Each pixel is replaced by the index of its nearest centroid (see
    ``quantize``); the resulting token sequence is embedded, shifted right by
    a learned start-of-sequence vector, passed through a stack of ``Block``
    layers and projected to logits over ``num_vocab`` tokens.

    Args:
        embed_dim: width of the token/position embeddings and of every layer.
        num_heads: attention heads per ``Block``.
        num_layers: number of ``Block`` layers.
        num_positions: size of the position-embedding table, i.e. the longest
            sequence the model can process.
        num_vocab: size of the token vocabulary and of the output logits.
        centroids_file: path to a NumPy file loaded with ``np.load``; it is
            indexed as [num_centroids, channels] by ``quantize``.
    """

    def __init__(
        self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
    ):
        super().__init__()

        # Stored as a non-trainable Parameter so the palette is part of the
        # module's parameters/state while receiving no gradient.
        self.centroids = nn.Parameter(
            torch.from_numpy(np.load(centroids_file)), requires_grad=False
        )

        self.embed_dim = embed_dim

        # start of sequence token
        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
        self.position_embeddings = nn.Embedding(num_positions, embed_dim)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(Block(embed_dim, num_heads))

        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
        # Fixed num_classes, this is not a classifier. The layer is not used by
        # forward(); it is kept so the set of registered parameters is unchanged.
        self.clf_head = nn.Linear(embed_dim, 10)

    def squared_euclidean_distance(self, a, b):
        """Return pairwise squared distances between rows of ``a`` and ``b``.

        Args:
            a: tensor indexed [n, d].
            b: tensor indexed [m, d].

        Returns:
            Tensor indexed [n, m] computed as ``|a|^2 - 2 a.b + |b|^2``.
        """
        b = torch.transpose(b, 0, 1)
        a2 = torch.sum(torch.square(a), dim=1, keepdim=True)
        b2 = torch.sum(torch.square(b), dim=0, keepdim=True)
        ab = torch.matmul(a, b)
        return a2 - 2 * ab + b2

    def quantize(self, x):
        """Map an image batch to nearest-centroid token sequences.

        Args:
            x: image batch indexed [batch, channels, height, width].

        Returns:
            Long tensor indexed [height * width, batch]; pixels are flattened
            in row-major order.
        """
        b, c, h, w = x.shape
        # [B, C, H, W] => [B, H, W, C] => one row per pixel
        pixels = x.permute(0, 2, 3, 1).reshape(-1, c)

        # The centroids come straight from np.load and may not share the
        # image's dtype (e.g. float64 vs float32), which matmul rejects.
        # Promote both to a common dtype; this is a no-op when they match.
        common = torch.promote_types(pixels.dtype, self.centroids.dtype)
        distances = self.squared_euclidean_distance(
            pixels.to(common), self.centroids.to(common)
        )
        tokens = torch.argmin(distances, dim=1)

        # [B * H * W] => [B, H * W] => [seq len, batch]
        return tokens.view(b, h * w).transpose(0, 1).contiguous()

    def unquantize(self, x):
        """Look up the centroid for every token index in ``x``.

        Returns a tensor with the shape of ``x`` plus a trailing channel axis.
        """
        return self.centroids[x]

    def forward(self, x, already_quantized=False):
        """Predict next-token logits for every position.

        Args:
            x: image batch indexed [b, c, h, w], or, when ``already_quantized``
                is true, a long tensor of token indices indexed [seq_len, batch].
            already_quantized: skip ``quantize`` and treat ``x`` as tokens.

        Returns:
            ``(logits, tokens)`` where ``logits`` is indexed
            [seq_len, batch, num_vocab] and ``tokens`` is the (possibly freshly
            quantized) token tensor. Because the input is shifted right by one,
            ``logits[i]`` is computed from tokens before position ``i`` only.
        """
        if not already_quantized:
            x = self.quantize(x)
        length, batch = x.shape

        h = self.token_embeddings(x)

        # prepend sos token and drop the last token (shift right by one)
        sos = self.sos.expand(1, batch, self.embed_dim)
        h = torch.cat([sos, h[:-1, :, :]], dim=0)

        # add positional embeddings; [length, 1, embed] broadcasts over batch
        positions = torch.arange(length, device=x.device).unsqueeze(-1)
        h = h + self.position_embeddings(positions)

        # transformer
        for layer in self.layers:
            h = checkpoint(layer, h)

        h = self.ln_f(h)

        logits = self.head(h)
        return logits, x