150 lines
4.7 KiB
Python
150 lines
4.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from trainer.injectors import Injector
|
|
from utils.util import checkpoint
|
|
|
|
|
|
def create_injector(opt, env):
    """Build the injector requested by ``opt['type']``.

    Args:
        opt: Injector configuration mapping; must contain the key ``'type'``.
        env: Environment object forwarded unchanged to the injector constructor.

    Returns:
        A ``ResolveInjector`` when ``opt['type']`` is ``'igpt_resolve'``,
        otherwise ``None`` so the caller can try other injector factories.

    Raises:
        KeyError: If ``opt`` has no ``'type'`` entry.
    """
    # Named ``injector_type`` rather than ``type`` so the builtin is not shadowed.
    injector_type = opt['type']
    if injector_type == 'igpt_resolve':
        return ResolveInjector(opt, env)
    return None
|
|
|
|
|
|
class ResolveInjector(Injector):
    """Injector that completes the second half of an image autoregressively.

    The input image is quantized into a token sequence by the configured
    generator, the first half of that sequence is kept as a prompt, and the
    remaining tokens are sampled one at a time from the generator's logits.

    Options read from ``opt``:
        generator: Key of the generator inside ``env['generators']``.
        num_samples: How many completions to sample per prompt.
        temperature: Divisor applied to the logits before the softmax.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.gen = opt['generator']
        self.samples = opt['num_samples']
        self.temperature = opt['temperature']

    def forward(self, state):
        """Sample completions for ``state[self.input]``.

        Args:
            state: Mapping that holds the input image under ``self.input``.
                The image is unpacked as a 4-D ``(b, c, h, w)`` tensor.

        Returns:
            A dict mapping ``self.output`` to the unquantized completions,
            permuted so the sample axis comes first and the spatial axes last.
        """
        gen = self.env['generators'][self.gen]
        # Generators are normally stored wrapped (e.g. DataParallel); fall back
        # to the object itself when there is no ``.module`` to unwrap.
        gen = getattr(gen, 'module', gen)

        img = state[self.input]
        _, _, h, w = img.shape

        # quantize() yields tokens laid out as [seq_len, batch].
        qimg = gen.quantize(img)
        seq_len = qimg.shape[0]
        prompt_len = seq_len // 2

        # Each prompt column is tiled ``samples`` times along the batch axis.
        output = qimg[:prompt_len, :].repeat(1, self.samples)

        # Placeholder row for the token being predicted. It must match the
        # tiled batch width and live on the same device as the tokens,
        # otherwise the concatenation below fails for batch > 1 or off-GPU.
        pad = torch.zeros(
            1, output.shape[1], dtype=torch.long, device=output.device
        )

        with torch.no_grad():
            # Generate every missing token, including the extra one when
            # seq_len is odd, so the final reshape to (h, w, -1) is exact.
            for _ in range(seq_len - prompt_len):
                logits, _ = gen(torch.cat((output, pad), dim=0), already_quantized=True)
                # Only the last position predicts the next, still-unknown token.
                logits = logits[-1, :, :] / self.temperature
                probs = F.softmax(logits, dim=-1)
                pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)
                output = torch.cat((output, pred), dim=0)

        output = gen.unquantize(output.reshape(h, w, -1))
        return {self.output: output.permute(2, 3, 0, 1).contiguous()}
|
|
|
|
|
|
class Block(nn.Module):
    """One causal transformer layer: masked self-attention, then an MLP.

    Args:
        embed_dim: Size of the feature dimension of the input.
        num_heads: Number of heads given to ``nn.MultiheadAttention``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        hidden_dim = embed_dim * 4

        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        """Apply the layer to ``x``, whose first axis is the sequence axis.

        Returns a tensor with the same shape as ``x``.
        """
        seq_len = x.shape[0]

        # -inf strictly above the diagonal and 0 elsewhere, so a position can
        # only attend to itself and to earlier positions.
        causal_mask = x.new_full((seq_len, seq_len), float("-inf")).triu(diagonal=1)

        # The residual is taken from the layer-normed input, not the raw input.
        normed = self.ln_1(x)
        attended, _ = self.attn(
            normed, normed, normed, attn_mask=causal_mask, need_weights=False
        )
        residual = normed + attended

        return residual + self.mlp(self.ln_2(residual))
|
|
|
|
|
|
class iGPT2(nn.Module):
    """Autoregressive transformer over images quantized to a centroid palette.

    Args:
        embed_dim: Width of the token and position embeddings.
        num_heads: Attention heads per ``Block``.
        num_layers: Number of stacked ``Block`` layers.
        num_positions: Size of the position-embedding table.
        num_vocab: Size of the token-embedding table and of the output logits.
        centroids_file: Path handed to ``np.load``; the loaded array is stored
            as a frozen parameter and used as the quantization palette.
    """

    def __init__(
        self, embed_dim, num_heads, num_layers, num_positions, num_vocab, centroids_file
    ):
        super().__init__()

        palette = torch.from_numpy(np.load(centroids_file))
        self.centroids = nn.Parameter(palette, requires_grad=False)
        self.embed_dim = embed_dim

        # Learned start-of-sequence embedding, drawn from a normal distribution.
        self.sos = nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
        self.position_embeddings = nn.Embedding(num_positions, embed_dim)

        self.layers = nn.ModuleList(
            [Block(embed_dim, num_heads) for _ in range(num_layers)]
        )

        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vocab, bias=False)
        # Fixed num_classes; this model is not a classifier and forward()
        # never calls this head.
        self.clf_head = nn.Linear(embed_dim, 10)

    def squared_euclidean_distance(self, a, b):
        """Return pairwise squared distances between the rows of ``a`` and ``b``.

        Uses the expansion ``|a|^2 - 2 a.b + |b|^2``; entry ``[i, j]`` of the
        result compares row ``i`` of ``a`` with row ``j`` of ``b``.
        """
        b_t = b.t()
        a_sq = a.square().sum(dim=1, keepdim=True)
        b_sq = b_t.square().sum(dim=0, keepdim=True)
        cross = a @ b_t
        return a_sq - 2 * cross + b_sq

    def quantize(self, x):
        """Map a ``[B, C, H, W]`` image batch to nearest-centroid indices.

        Returns a contiguous index tensor of shape ``[H * W, B]``.
        """
        batch, channels, _, _ = x.shape
        # [B, C, H, W] => [B, H, W, C] => one row per pixel.
        pixels = x.permute(0, 2, 3, 1).contiguous().view(-1, channels)
        nearest = self.squared_euclidean_distance(pixels, self.centroids).argmin(dim=1)
        # One flattened sequence per image, then sequence-first layout.
        return nearest.view(batch, -1).transpose(0, 1).contiguous()

    def unquantize(self, x):
        """Look up the centroid vector for every index in ``x``."""
        return self.centroids[x]

    def forward(self, x, already_quantized=False):
        """Compute next-token logits for an image batch.

        Args:
            x: Images of shape ``[b, c, h, w]``, or, when ``already_quantized``
                is true, token indices of shape ``[seq_len, batch]``.
            already_quantized: Skip the call to ``quantize`` when true.

        Returns:
            ``(logits, tokens)``: the output of the vocabulary head for every
            position, and the token indices that were fed to the model.
        """
        tokens = x if already_quantized else self.quantize(x)
        seq_len, batch = tokens.shape

        embedded = self.token_embeddings(tokens)

        # Shift right: put the sos embedding first and drop the last token.
        start = torch.ones(1, batch, self.embed_dim, device=tokens.device) * self.sos
        hidden = torch.cat([start, embedded[:-1, :, :]], dim=0)

        # Add positional embeddings, broadcast across the batch axis.
        positions = torch.arange(seq_len, device=tokens.device).unsqueeze(-1)
        hidden = hidden + self.position_embeddings(positions).expand_as(hidden)

        # Transformer stack; each layer is run through checkpoint().
        for layer in self.layers:
            hidden = checkpoint(layer, hidden)

        logits = self.head(self.ln_f(hidden))
        return logits, tokens
|
|
|