from functools import partial
from itertools import islice, cycle

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

from models.lucidrains.dalle.reversible import ReversibleSequence, SequentialSequence
from models.lucidrains.dalle.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention

from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, depth = 1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth

# classes

class DivideMax(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim = self.dim, keepdim = True).detach()
        return x / maxes

# https://arxiv.org/abs/2103.17239

class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# layer norm

class PreNorm(nn.Module):
    def __init__(self, dim, fn, sandwich = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)

# feed forward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0., mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# token shift classes

class PreShiftToken(nn.Module):
    def __init__(self, fn, image_size, seq_len):
        super().__init__()
        self.fn = fn
        self.image_size = image_size
        self.seq_len = seq_len

    def forward(self, x, **kwargs):
        n = x.shape[1]
        seq_len, image_size = self.seq_len, self.image_size
        img_seq_len = image_size ** 2
        text_len = seq_len - img_seq_len + 1
        padding = seq_len - n + 1

        # get text and image tokens

        x_text, x_img = x[:, :text_len], x[:, text_len:]
        x_img = F.pad(x_img, (0, 0, 0, padding))
        x_img = rearrange(x_img, 'b (h w) d -> b h w d', h = image_size)

        # shift half the channels of the text tokens one position along the sequence (token shift)

        x_text_shift, x_text_pass = x_text.chunk(2, dim = -1)
        x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1))
        x_text = torch.cat((x_text_shift, x_text_pass), dim = -1)

        # shift from top, left for image tokens

        x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim = -1)
        x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1))
        x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1))
        x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim = -1)

        # merge text and image sequence back together

        x_img = rearrange(x_img, 'b h w d -> b (h w) d')
        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
        return self.fn(x, **kwargs)

# main transformer class

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        reversible = False,
        causal = True,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_types = None,
        image_fmap_size = None,
        oned_fmap_size = None,
        sparse_attn = False,
        stable = False,
        sandwich_norm = False,
        shift_tokens = False,
        rotary_emb = True
    ):
        super().__init__()
        layers = nn.ModuleList([])
        sparse_layer = cast_tuple(sparse_attn, depth)

        attn_types = default(attn_types, ('full',))
        attn_types = cast_tuple(attn_types)
        attn_type_layer = islice(cycle(attn_types), depth)

        for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
            if attn_type == 'full':
                attn_class = partial(Attention, stable = stable)
            elif attn_type == 'sparse':
                attn_class = SparseAttention
            elif attn_type == 'axial_row':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
            elif attn_type == 'axial_col':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
            elif attn_type == 'conv_like':
                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
            elif attn_type == 'mlp':
                attn_class = partial(gMLPBlock, seq_len = seq_len)
            else:
                raise ValueError(f'attention type "{attn_type}" is not valid')

            if attn_type != 'mlp':
                attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
            else:
                attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)

            ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)

            if shift_tokens:
                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))

            layers.append(nn.ModuleList([
                LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
                LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
            ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence
        route_attn = ((True, False),) * depth
        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}

        self.layers = execute_type(layers, args_route = attn_route_map)

        # generate positional embeddings for rotary

        pos_emb = None
        if rotary_emb:
            assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on'

            rot_dim = dim_head // 3
            img_seq_len = (image_fmap_size ** 2) if image_fmap_size is not None else oned_fmap_size
            text_len = seq_len - img_seq_len + 1

            text_pos_emb = RotaryEmbedding(dim = rot_dim)
            img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel')

            text_freqs = text_pos_emb(torch.arange(text_len))
            img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192))  # image is given a position far away from text
            text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim = 0)

            img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size if image_fmap_size is not None else oned_fmap_size))
            img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), rearrange(img_freqs_axial, 'j d -> () j d')), dim = -1)
            img_freqs = rearrange(img_freqs, 'h w d -> (h w) d')

            text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.))  # text is given a position of -10, far away from the image axial positions, which lie in the range [-1, 1]
            text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim = -1)

            img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0)

            pos_emb = torch.cat((text_freqs, img_freqs), dim = -1)
            pos_emb = rearrange(pos_emb, 'n d -> () n d')

        self.register_buffer('pos_emb', pos_emb)

    def forward(self, x, **kwargs):
        return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
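
# minimal usage sketch (illustrative only): the hyperparameters and shapes below are
# hypothetical, and running this assumes the vendored DALLE-pytorch modules imported above
# (Attention, SequentialSequence, etc.) behave as in lucidrains/DALLE-pytorch

if __name__ == '__main__':
    image_fmap_size = 8                              # hypothetical 8 x 8 image feature map -> 64 image positions
    text_seq_len = 128                               # hypothetical text length
    seq_len = text_seq_len + image_fmap_size ** 2    # total sequence length the transformer is sized for

    transformer = Transformer(
        dim = 512,
        depth = 2,
        seq_len = seq_len,
        heads = 8,
        dim_head = 64,
        attn_types = ('full',),
        image_fmap_size = image_fmap_size,
        rotary_emb = True
    )

    tokens = torch.randn(1, seq_len - 1, 512)        # (batch, sequence, dim) token embeddings
    out = transformer(tokens)                        # output keeps the input shape: (1, seq_len - 1, 512)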