Remove obsolete lucidrains DALLE stuff, re-create it in a dedicated folder

James Betker 2021-12-22 13:44:02 -07:00
parent a42b94ab72
commit 09f7f3e615
10 changed files with 632 additions and 771 deletions

View File

@@ -1,242 +0,0 @@
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols, sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get
class ResBlock(nn.Module):
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
nn.BatchNorm1d(chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
nn.BatchNorm1d(chan)
)
def forward(self, x):
return F.relu(self.net(x) + x)
class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
ResBlock(channels//4),
ResBlock(channels//4),
nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
nn.BatchNorm1d(channels//2),
nn.ReLU(),
ResBlock(channels//2),
ResBlock(channels//2),
ResBlock(channels//2),
nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
ResBlock(channels),
ResBlock(channels),
ResBlock(channels)
)
def forward(self, x):
return self.encoder(x)
class GptAsr(nn.Module):
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000):
super().__init__()
self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
self.max_symbols_per_phrase = max_symbols_per_phrase
self.model_dim = model_dim
self.max_mel_frames = self.max_mel_frames
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
self.mel_encoder = MelEncoder(model_dim)
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2 + self.max_symbols_per_phrase + self.max_mel_frames, heads=heads,
attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.max_mel_frames)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
def get_logits(self, mel_inputs, text_targets):
# Pad the front and back. The front pad serves as the "START" token.
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
text_emb = self.text_embedding(text_targets)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([mel_emb, text_emb], dim=1)
enc = self.gpt(emb)
text_logits = self.final_norm(enc[:, self.max_mel_frames:])
text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1)
return text_logits
def forward(self, mel_inputs, text_targets):
text_logits = self.get_logits(mel_inputs, text_targets)
loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
return loss_text.mean(), text_logits
def inference_beam_topk(self, mel, fn='inference_beam'):
def topk_sampler(distribution, k):
return torch.topk(distribution, k=k, dim=-1)
return getattr(self, fn)(mel, topk_sampler)
def inference_beam_sampled(self, mel, fn='inference_beam'):
def multinomial_sampler(distribution, k):
indices = torch.multinomial(distribution, num_samples=k, replacement=False)
values = torch.gather(distribution, dim=1, index=indices)
class container:
def __init__(self, i, v):
self.indices = i
self.values = v
return container(indices, values)
return getattr(self, fn)(mel, multinomial_sampler)
def inference_beam(self, mel_inputs, sampler_fn):
beam_width = 16
temperature = .8
b, _, s = mel_inputs.shape
assert b == 1 # Beam search only works on batches of one.
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
probabilities = torch.ones((b,), device=mel_emb.device)
while text_seq.shape[-1] < self.max_symbols_per_phrase:
text_emb = self.text_embedding(text_seq)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
if text_emb.shape[0] != mel_emb.shape[0]:
mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
emb = torch.cat([mel_emb, text_emb], dim=1)
enc = self.gpt(emb)
text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
text_logits = self.text_head(text_logits)
topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
probabilities, sort_indices = torch.sort(probabilities, descending=True)
probabilities = probabilities[:beam_width]
text_seq = text_seq.repeat_interleave(beam_width, dim=0)
codes = topk.indices.flatten()
text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1)
text_seq = text_seq[sort_indices]
text_seq = text_seq[:beam_width]
# PAD doubles as a stop token. PAD=0.
if torch.all(torch.any(text_seq == 0, dim=1)):
break
if text_seq.shape[1] >= self.max_symbols_per_phrase:
print("Warning! Encountered the symbol limit before a pad token. Output is likely wrong.")
return text_seq
def inference_beam_opt(self, mel_inputs, sampler_fn):
beam_width = 16
temperature = .8
b, _, s = mel_inputs.shape
assert b == 1 # Beam search only works on batches of one.
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
intermediates = []
text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
probabilities = torch.ones((b,), device=mel_emb.device)
while text_seq.shape[-1] < self.max_symbols_per_phrase:
text_emb = self.text_embedding(text_seq)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
if text_emb.shape[0] != mel_emb.shape[0]:
mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
emb = torch.cat([mel_emb, text_emb], dim=1)
if len(intermediates) == 0:
enc, intermediates = self.gpt(emb, return_intermediates=True)
intermediates = [(i[0].repeat(beam_width, 1, 1),
i[1].repeat(beam_width, 1, 1)) for i in intermediates]
else:
enc, intermediates = self.gpt.infer_last_two(emb, intermediates)
text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
text_logits = self.text_head(text_logits)
topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
probabilities, sort_indices = torch.sort(probabilities, descending=True)
probabilities = probabilities[:beam_width]
text_seq = text_seq.repeat_interleave(beam_width, dim=0)
codes = topk.indices.flatten()
text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1)
text_seq = text_seq[sort_indices]
text_seq = text_seq[:beam_width]
# PAD doubles as a stop token. PAD=0.
if torch.all(torch.any(text_seq == 0, dim=1)):
break
if text_seq.shape[1] >= self.max_symbols_per_phrase:
print("Warning! Encountered the symbol limit before a pad token. Output is likely wrong.")
return text_seq
@register_model
def register_gpt_asr(opt_net, opt):
return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
# Quick script that loads a model and halves the number of layers, then saves that model.
def distill():
gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
rc = 0
i = 0
while i < len(gpt.gpt.layers.layers):
if rc % 2 != 0:
del gpt.gpt.layers.layers[i]
else:
i += 1
rc += 1
torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
if __name__ == '__main__':
gpt = GptAsr(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2).cuda()
#l = gpt(torch.randn(2,80,800),
# torch.randint(high=len(symbols), size=(2,180)))
with torch.no_grad():
t = torch.randn(1,80,800).cuda()
start = time()
s = gpt.inference_beam_topk(t)
print(time()-start)
start = time()
o = gpt.inference_beam_topk(t, fn='inference_beam_opt')
print(time()-start)
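The beam bookkeeping used by `inference_beam` above (expand every hypothesis with its top-k continuations, multiply in their probabilities, sort, and keep the best `beam_width`) can be exercised on its own. A minimal sketch with a random stand-in for the model's next-token logits; the helper name `beam_step` is illustrative and not part of the repo:

```python
import torch

def beam_step(seqs, probs, logits, beam_width):
    # seqs: (n, t) current hypotheses, probs: (n,) their scores,
    # logits: (n, vocab) next-token scores for each hypothesis.
    topk = torch.topk(torch.softmax(logits, dim=-1), k=beam_width, dim=-1)
    # Every hypothesis spawns beam_width children; their scores multiply in.
    probs = probs.repeat_interleave(beam_width, dim=0) * topk.values.flatten()
    probs, order = torch.sort(probs, descending=True)
    seqs = seqs.repeat_interleave(beam_width, dim=0)
    seqs = torch.cat([seqs, topk.indices.flatten().unsqueeze(1)], dim=1)
    # Keep only the beam_width best-scoring children.
    return seqs[order][:beam_width], probs[:beam_width]

seqs = torch.zeros((1, 1), dtype=torch.long)   # a single START hypothesis
probs = torch.ones((1,))
logits = torch.randn(1, 10)                    # stand-in for text_logits[:, -1]
seqs, probs = beam_step(seqs, probs, logits, beam_width=4)
print(seqs.shape, probs.shape)                 # torch.Size([4, 2]) torch.Size([4])
```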

View File

@@ -1,102 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols, sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get
class ResBlock(nn.Module):
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
nn.BatchNorm1d(chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
nn.BatchNorm1d(chan)
)
def forward(self, x):
return F.relu(self.net(x) + x)
class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
ResBlock(channels//4),
ResBlock(channels//4),
nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
nn.BatchNorm1d(channels//2),
nn.ReLU(),
ResBlock(channels//2),
ResBlock(channels//2),
ResBlock(channels//2),
nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
ResBlock(channels),
ResBlock(channels),
ResBlock(channels)
)
def forward(self, x):
return self.encoder(x)
class GptSegmentor(nn.Module):
MAX_MEL_FRAMES = 2000 // 4
def __init__(self, layers=8, model_dim=512, heads=8):
super().__init__()
self.model_dim = model_dim
self.max_mel_frames = self.MAX_MEL_FRAMES
self.mel_encoder = MelEncoder(model_dim)
self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=self.MAX_MEL_FRAMES, heads=heads,
attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
self.final_norm = nn.LayerNorm(model_dim)
self.start_head = nn.Linear(model_dim, 1)
self.stop_head = nn.Linear(model_dim, 1)
def forward(self, mel_inputs, start_labels=None, end_labels=None):
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
enc = self.gpt(mel_emb)
logits = self.final_norm(enc)
stop_logits = self.stop_head(logits)
start_logits = self.start_head(logits)
if start_labels is not None:
# Compute loss
start_loss = F.binary_cross_entropy_with_logits(start_logits.squeeze(-1), start_labels.float())
end_loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), end_labels.float())
return start_loss.mean(), end_loss.mean()
else:
return start_logits, stop_logits
@register_model
def register_gpt_segmentor(opt_net, opt):
return GptSegmentor(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
gpt = GptSegmentor()
start_loss, stop_loss = gpt(torch.randn(3,80,94),
torch.zeros(3,24), torch.zeros(3,24))  # MelEncoder downsamples by 4, so labels are at 1/4 the frame rate
print(start_loss, stop_loss)
#o = gpt.infer(torch.randint(high=24, size=(2,60)))
#print(o.shape)
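GptSegmentor returns raw per-frame start/stop logits at inference time (when no labels are given). A hedged sketch of one way to turn them into boundary frame indices by thresholding the sigmoid; `boundaries_from_logits` and the 0.5 threshold are assumptions, not part of the repo:

```python
import torch

def boundaries_from_logits(start_logits, stop_logits, threshold=0.5):
    # Logits arrive as (batch, frames, 1); squeeze and take the sigmoid so the
    # BCE-with-logits training objective lines up with a simple threshold.
    start_p = torch.sigmoid(start_logits.squeeze(-1))
    stop_p = torch.sigmoid(stop_logits.squeeze(-1))
    starts = [torch.nonzero(p > threshold).flatten() for p in start_p]
    stops = [torch.nonzero(p > threshold).flatten() for p in stop_p]
    return starts, stops

start_logits, stop_logits = torch.randn(2, 100, 1), torch.randn(2, 100, 1)
starts, stops = boundaries_from_logits(start_logits, stop_logits)
print(len(starts), starts[0].dtype)   # 2 torch.int64
```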

View File

@@ -1,176 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
class GptTts(nn.Module):
MAX_SYMBOLS_PER_PHRASE = 200
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
MEL_DICTIONARY_SIZE = 512+3
MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
def __init__(self, layers=8, model_dim=512, heads=8):
super().__init__()
max_mel_frames = 900 * 1 // 4 # 900 is the max number of MEL frames. The VQVAE outputs 1/4 of the input mel as tokens.
self.model_dim = model_dim
self.max_mel_frames = max_mel_frames
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=heads,
attn_dropout=.1, ff_dropout=.1)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)
def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
text_emb = self.text_embedding(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
mel_emb = self.mel_embedding(mel_targets)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device))
emb = torch.cat([text_emb, mel_emb], dim=1)
enc = self.gpt(emb)
# Compute logits for text and mel heads
text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
text_logits = self.text_head(text_logits)
mel_logits = self.mel_head(mel_logits)
# Compute loss
text_targets = text_inputs[:,1:]
text_logits = text_logits.permute(0,2,1)[:,:,:-1] # The last element of the logits is unneeded because the input to the transformer contains a <EOS> token for both text and mel.
loss_text = F.cross_entropy(text_logits, text_targets, reduction='none')
mel_targets = mel_targets[:,1:]
mel_logits = mel_logits.permute(0,2,1)[:,:,:-1]
loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')
# Fix up mel_logits so it can go into a VAE decoder as well.
mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1)
mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1])
mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0)
mel_codes = mel_codes[:,:-1] # Strip off <EOS> token too (or padding). The important part is that the output sequence length is identical to the VAE input.
extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD
mel_codes = mel_codes * extra_mask
# This class also returns the mel_targets for validation purposes. Format those.
mel_targets = mel_targets[:,:-1]
mel_targets = mel_targets * (mel_targets < self.MEL_DICTIONARY_SIZE-3)
return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets
def inference(self, text_inputs):
b, s = text_inputs.shape
text_emb = self.text_embedding(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device))
mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
stop_encountered = torch.zeros((b,), device=text_emb.device)
while not torch.all(stop_encountered) and mel_seq.shape[-1] < self.max_mel_frames:
mel_emb = self.mel_embedding(mel_seq)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([text_emb, mel_emb], dim=1)
enc = self.gpt(emb)
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
mel_logits = self.mel_head(mel_logits)
mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
mel_seq = torch.cat([mel_seq, mel_codes[:, -1].unsqueeze(1)], dim=1)
stop_encountered = torch.logical_or(stop_encountered, mel_seq[:,-1] == self.MEL_STOP_TOKEN)
if mel_seq.shape[-1] >= self.max_mel_frames:
print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
# Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
mel_seq = mel_seq[:, 1:-1] # Remove first and last tokens, which were artificially added for GPT
mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens.
return mel_seq
def inference_beam_topk(self, text):
def topk_sampler(distribution, k):
return torch.topk(distribution, k=k, dim=-1)
return self.inference_beam(text, topk_sampler)
def inference_beam_sampled(self, text):
def multinomial_sampler(distribution, k):
indices = torch.multinomial(distribution, num_samples=k, replacement=False)
values = torch.gather(distribution, dim=1, index=indices)
class container:
def __init__(self, i, v):
self.indices = i
self.values = v
return container(indices, values)
return self.inference_beam(text, multinomial_sampler)
def inference_beam(self, text_inputs, sampler_fn):
beam_width = 16
temperature = .8
b, s = text_inputs.shape
assert b == 1 # Beam search only works on batches of one.
text_emb = self.text_embedding(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device))
mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
probabilities = torch.ones((b,), device=text_emb.device)
while mel_seq.shape[-1] < self.max_mel_frames:
mel_emb = self.mel_embedding(mel_seq)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
if text_emb.shape[0] != mel_emb.shape[0]:
text_emb = text_emb.repeat(mel_emb.shape[0], 1, 1)
emb = torch.cat([text_emb, mel_emb], dim=1)
enc = self.gpt(emb)
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
mel_logits = self.mel_head(mel_logits)
topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width)
probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
probabilities, sort_indices = torch.sort(probabilities, descending=True)
probabilities = probabilities[:beam_width]
mel_seq = mel_seq.repeat_interleave(beam_width, dim=0)
codes = topk.indices.flatten()
mel_seq = torch.cat([mel_seq, codes.unsqueeze(1)], dim=1)
mel_seq = mel_seq[sort_indices]
mel_seq = mel_seq[:beam_width]
if torch.all(torch.any(mel_seq == self.MEL_STOP_TOKEN, dim=1)):
break
if mel_seq.shape[1] >= self.max_mel_frames:
print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
# Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
mel_seq = mel_seq[0, 1:-1].unsqueeze(0) # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT
mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens.
return mel_seq
@register_model
def register_gpt_tts(opt_net, opt):
return GptTts(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
gpt = GptTts()
l1, l2, i = gpt(torch.randint(high=24, size=(2,60)),
torch.tensor([55,58]),
torch.randint(high=512, size=(2,220)),  # must fit within max_mel_frames (900 // 4 = 225)
torch.tensor([210,215]))
print(i.shape)
#o = gpt.infer(torch.randint(high=24, size=(2,60)))
#print(o.shape)
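The clean-up applied to `mel_codes`/`mel_targets` above (zero out anything the DVAE does not know about, i.e. the START/STOP/PAD ids at the top of the 515-entry dictionary) is easy to check in isolation:

```python
import torch

MEL_DICTIONARY_SIZE = 512 + 3   # 512 VQVAE codes + START/STOP/PAD, as in GptTts

# Any id >= 512 is a special token; multiplying by the boolean mask maps it
# to 0 so the DVAE only ever sees valid codebook indices.
codes = torch.tensor([[513, 7, 100, 511, 514, 0]])
cleaned = codes * (codes < MEL_DICTIONARY_SIZE - 3)
print(cleaned)   # tensor([[  0,   7, 100, 511,   0,   0]])
```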

View File

@@ -1,236 +0,0 @@
from inspect import isfunction
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# helpers
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
from utils.util import checkpoint
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, depth = 1):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else (val,) * depth
class DivideMax(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
maxes = x.amax(dim = self.dim, keepdim = True)
return x / maxes
# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
def __init__(self, dim, depth, fn):
super().__init__()
if depth <= 18:
init_eps = 0.1
elif depth > 18 and depth <= 24:
init_eps = 1e-5
else:
init_eps = 1e-6
scale = torch.zeros(1, 1, dim).fill_(init_eps)
self.scale = nn.Parameter(scale)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
class FeedForward(nn.Module):
def __init__(self, dim, dropout = 0., mult = 4.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
)
def forward(self, x, only_last_two_elements=False):
if only_last_two_elements:
h = x[:, -2:]
h = self.net(h)
return torch.cat([x[:, :-2], h], dim=1)
else:
return self.net(x)
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True)
return (t * alpha).softmax(dim = dim)
# classes
class Attention(nn.Module):
def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.seq_len = seq_len
self.scale = dim_head ** -0.5
self.stable = stable
self.non_causal_sequence_partition = non_causal_sequence_partition
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None, only_last_two_elements=False):
b, n, _, h, device = *x.shape, self.heads, x.device
softmax = torch.softmax if not self.stable else stable_softmax
# TODO: Q and V do not need to be recomputed for existing elements if intermediate_latents is specified. V would need to be cached, though.
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
q = q * self.scale
if only_last_two_elements:
q = q[:, :, -2:]
assert not exists(mask) # Don't know how to resolve this (currently)
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
mask_value = max_neg_value(dots)
if exists(mask):
mask = rearrange(mask, 'b j -> b () () j')
dots.masked_fill_(~mask, mask_value)
del mask
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1)
if self.non_causal_sequence_partition > 0:
non_causal_mask = torch.ones((i, j), device=device)
non_causal_mask[:, :self.non_causal_sequence_partition] = 0
mask = mask * non_causal_mask
dots.masked_fill_(mask.bool(), mask_value)
attn = softmax(dots, dim=-1)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class Transformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
seq_len,
reversible = False,
heads = 8,
dim_head = 64,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
sparse_attn = False,
stable = False,
non_causal_sequence_partition=0,
):
super().__init__()
layers = nn.ModuleList([])
sparse_layer = cast_tuple(sparse_attn, depth)
for ind, sparse_attn in zip(range(depth), sparse_layer):
attn = Attention(dim, stable=stable, non_causal_sequence_partition = non_causal_sequence_partition, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
layers.append(nn.ModuleList([
LayerScale(dim, ind + 1, PreNorm(dim, attn)),
LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
]))
# TODO: Remove this nonsense. I don't want to support reversible sequences and this is just a mess.
execute_type = ReversibleSequence if reversible else SequentialSequence
route_attn = ((True, False),) * depth
attn_route_map = {'mask': route_attn}
self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
self.depth = depth
def forward(self, x, return_intermediates=False):
intermediates = []
for attn, ff in self.layers.layers:
x_ff = x + checkpoint(attn, x)
x = x_ff + ff(x_ff)
if return_intermediates:
intermediates.append((x_ff, x))
if return_intermediates:
return x, intermediates
else:
return x
def infer_last_two(self, x, prev_intermediates):
"""
Performs a forward pass only on the last two elements of the given sequence (while still allowing them to attend
to all other elements). This is useful for faster autoregressive decoding.
The last two elements matter because, at inference time, the last element is the prediction candidate and the
second-to-last element is the newly selected element produced by the autoregressive search.
"""
assert(len(prev_intermediates) == self.depth)
new_intermediates = []
for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
x_ff = attn(x, only_last_two_elements=True)
# Note that the attention output (x_ff) now covers only the last two elements. Conjoin it with the cached int_ff latents to rebuild the full-length residual.
x_ff = x + torch.cat([int_ff[:,:-1], x_ff], dim=1)
x = x_ff + ff(x_ff, only_last_two_elements=True)
new_intermediates.append((x_ff, x))
return x, new_intermediates
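The idea behind `infer_last_two` can be sketched without the caching machinery: only the newest two positions need fresh attention outputs, but their queries still attend over every key in the sequence. A toy check (single head, no causal mask, illustrative names only):

```python
import torch

def attend_last_two(q, k, v):
    # Queries for the last two positions only, keys/values for the whole sequence.
    dots = torch.einsum('b i d, b j d -> b i j', q[:, -2:], k)
    return torch.einsum('b i j, b j d -> b i d', dots.softmax(dim=-1), v)

b, n, d = 1, 16, 32
q = k = v = torch.randn(b, n, d)
full_dots = torch.einsum('b i d, b j d -> b i j', q, k)
full = torch.einsum('b i j, b j d -> b i d', full_dots.softmax(dim=-1), v)
# The tail of the full result matches the cheap last-two computation.
print(torch.allclose(attend_last_two(q, k, v), full[:, -2:], atol=1e-6))   # True
```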

View File

View File

@@ -0,0 +1 @@
# This directory contains some useful code from https://github.com/lucidrains/DALLE-pytorch/tree/main/dalle_pytorch
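With the dependencies added to requirements.txt in this commit installed, the vendored modules can be smoke-tested directly. The import path below is taken from the other files in this commit (`models.lucidrains.dalle.attention`); adjust it if the package lives elsewhere in your checkout:

```python
import torch
from models.lucidrains.dalle.attention import Attention

# Plain full causal attention from the vendored DALLE-pytorch code.
attn = Attention(dim=512, seq_len=64, causal=True, heads=8, dim_head=64)
x = torch.randn(2, 64, 512)
print(attn(x).shape)   # torch.Size([2, 64, 512])
```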

View File

@@ -0,0 +1,384 @@
from inspect import isfunction
from math import ceil
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from rotary_embedding_torch import apply_rotary_emb
# helpers
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
return (t * alpha).softmax(dim = dim)
def apply_pos_emb(pos_emb, qkv):
n = qkv[0].shape[-2]
pos_emb = pos_emb[..., :n, :]
return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
# classes
class Attention(nn.Module):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.seq_len = seq_len
self.scale = dim_head ** -0.5
self.stable = stable
self.causal = causal
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None, rotary_pos_emb = None):
b, n, _, h, device = *x.shape, self.heads, x.device
softmax = torch.softmax if not self.stable else stable_softmax
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
q = q * self.scale
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
mask_value = max_neg_value(dots)
if exists(mask):
mask = rearrange(mask, 'b j -> b () () j')
dots.masked_fill_(~mask, mask_value)
del mask
if self.causal:
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
dots.masked_fill_(mask, mask_value)
attn = softmax(dots, dim=-1)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
# sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
class SparseConvCausalAttention(nn.Module):
def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
super().__init__()
assert kernel_size % 2 == 1, 'kernel size must be odd'
inner_dim = dim_head * heads
self.seq_len = seq_len
self.heads = heads
self.scale = dim_head ** -0.5
self.image_size = image_size
self.kernel_size = kernel_size
self.dilation = dilation
self.stable = stable
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None, rotary_pos_emb = None):
b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
softmax = torch.softmax if not self.stable else stable_softmax
img_seq_len = img_size ** 2
text_len = seq_len + 1 - img_seq_len
# padding
padding = seq_len - n + 1
mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
x = F.pad(x, (0, 0, 0, padding), value = 0)
mask = mask[:, :text_len]
# derive query / keys / values
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
q *= self.scale
((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
# text attention
dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
mask_value = max_neg_value(dots_text)
i, j = dots_text.shape[-2:]
text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
dots_text.masked_fill_(text_causal_mask, mask_value)
attn_text = softmax(dots_text, dim = -1)
out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
# image attention
effective_kernel_size = (kernel_size - 1) * dilation + 1
padding = effective_kernel_size // 2
k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img))
k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))
# let image attend to all of text
dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)
# calculate causal attention for local convolution
i, j = dots_image.shape[-2:]
img_seq = torch.arange(img_seq_len, device = device)
k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len) # padding set to be max, so it is never attended to
k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
# mask image attention
q_img_indices = rearrange(img_seq, 'i -> () i ()')
causal_mask = q_img_indices < k_img_indices
# concat text mask with image causal mask
causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h)
mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
mask = torch.cat((~mask, causal_mask), dim = -1)
# image can attend to all of text
dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
dots.masked_fill_(mask, mask_value)
attn = softmax(dots, dim = -1)
# aggregate
attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]
out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img)
out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text)
out_image = out_image_to_image + out_image_to_text
# combine attended values for both text and image
out = torch.cat((out_text, out_image), dim = 1)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
out = self.to_out(out)
return out[:, :n]
# sparse axial causal attention
class SparseAxialCausalAttention(nn.Module):
def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
super().__init__()
assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
self.axis = axis
inner_dim = dim_head * heads
self.seq_len = seq_len
self.heads = heads
self.scale = dim_head ** -0.5
self.image_size = image_size
self.stable = stable
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None, rotary_pos_emb = None):
b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
softmax = torch.softmax if not self.stable else stable_softmax
img_seq_len = img_size ** 2
text_len = seq_len + 1 - img_seq_len
# padding
padding = seq_len - n + 1
mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
x = F.pad(x, (0, 0, 0, padding), value = 0)
mask = mask[:, :text_len]
# derive queries / keys / values
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
q *= self.scale
((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
# text attention
dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
mask_value = max_neg_value(dots_text)
i, j = dots_text.shape[-2:]
text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
dots_text.masked_fill_(text_causal_mask, mask_value)
attn_text = softmax(dots_text, dim = -1)
out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
# image attention
split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'
# split out axis
q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))
# similarity
dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
# mask so image has full attention to text, but causal along axis
bh, x, i, j = dots.shape
causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)
mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
mask = torch.cat((~mask, causal_mask), dim = -1)
dots.masked_fill_(mask, mask_value)
# attention.
attn = softmax(dots, dim = -1)
# aggregate
attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
out_image = out_image_to_image + out_image_to_text
# merge back axis
out_image = rearrange(out_image, merge_axis_einops, x = img_size)
# combine attended values for both text and image
out = torch.cat((out_text, out_image), dim = 1)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
out = self.to_out(out)
return out[:, :n]
# microsoft sparse attention CUDA kernel
class SparseAttention(Attention):
def __init__(
self,
*args,
block_size = 16,
text_seq_len = 256,
num_random_blocks = None,
**kwargs
):
super().__init__(*args, **kwargs)
from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
self.block_size = block_size
num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)
global_block_indices = list(range(ceil(text_seq_len / block_size)))
self.attn_fn = SparseSelfAttention(
sparsity_config = VariableSparsityConfig(
num_heads = self.heads,
block = self.block_size,
num_random_blocks = num_random_blocks,
global_block_indices = global_block_indices,
attention = 'unidirectional' if self.causal else 'bidirectional'
),
max_seq_length = self.seq_len,
attn_mask_mode = 'add'
)
def forward(self, x, mask = None, rotary_pos_emb = None):
b, n, _, h, device = *x.shape, self.heads, x.device
remainder = n % self.block_size
mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
if remainder > 0:
padding = self.block_size - remainder
x = F.pad(x, (0, 0, 0, padding), value = 0)
mask = F.pad(mask, (0, padding), value = False)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
key_pad_mask = None
if exists(mask):
key_pad_mask = ~mask
attn_mask = None
if self.causal:
i, j = q.shape[-2], k.shape[-2]
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
attn_mask = torch.zeros(i, j, device = device).to(q)
mask_value = max_neg_value(q) / 2
attn_mask.masked_fill_(mask, mask_value)
out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out[:, :n]
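A quick numerical sanity check of the `stable_softmax` helper defined above: scaling down by `alpha` before the max-subtraction and scaling back up afterwards reproduces the ordinary softmax for well-behaved inputs; it only changes where the overflow-guarding subtraction happens.

```python
import torch

def stable_softmax(t, dim=-1, alpha=32 ** 2):
    # Same as the helper above: alpha * (t/alpha - max(t/alpha)) == t - max(t).
    t = t / alpha
    t = t - torch.amax(t, dim=dim, keepdim=True).detach()
    return (t * alpha).softmax(dim=dim)

t = torch.randn(4, 16)
print(torch.allclose(stable_softmax(t), torch.softmax(t, dim=-1), atol=1e-6))   # True
```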

View File

@@ -1,12 +1,10 @@
import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
# for routing arguments into the functions of the reversible layer
from utils.util import checkpoint
def route_args(router, args, depth):
routed_args = [(dict(), dict()) for _ in range(depth)]
matched_keys = [key for key in args.keys() if key in router]
@@ -126,25 +124,20 @@ class _ReversibleFunction(Function):
return dy, None, None
class SequentialSequence(nn.Module):
def __init__(self, layers, args_route = {}, layer_dropout = 0., checkpoint=False):
def __init__(self, layers, args_route = {}, layer_dropout = 0.):
super().__init__()
assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
self.layers = layers
self.args_route = args_route
self.layer_dropout = layer_dropout
self.checkpoint = checkpoint
def forward(self, x, **kwargs):
args = route_args(self.args_route, kwargs, len(self.layers))
layers_and_args = list(zip(self.layers, args))
for (f, g), (f_args, g_args) in layers_and_args:
if self.checkpoint:
x = x + f(x, **f_args)
x = x + g(x, **g_args)
else:
x = x + checkpoint(f, x, **f_args)
x = x + checkpoint(g, x, **g_args)
x = x + f(x, **f_args)
x = x + g(x, **g_args)
return x
class ReversibleSequence(nn.Module):
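For reference, the `args_route` convention consumed by `SequentialSequence`: with a route of `((True, False),) * depth`, a routed kwarg is handed to the attention callable of each layer pair but withheld from the feed-forward. A simplified re-implementation of the routing step (not the verbatim `route_args` helper) to make that behaviour concrete:

```python
def route_args_sketch(router, args, depth):
    # router: {kwarg_name: ((to_attn, to_ff), ...) one pair per layer}
    routed = [({}, {}) for _ in range(depth)]
    for key in (k for k in args if k in router):
        for d, (to_attn, to_ff) in enumerate(router[key]):
            if to_attn:
                routed[d][0][key] = args[key]
            if to_ff:
                routed[d][1][key] = args[key]
    return routed

depth = 2
print(route_args_sketch({'mask': ((True, False),) * depth}, {'mask': 'M'}, depth))
# [({'mask': 'M'}, {}), ({'mask': 'M'}, {})]
```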

View File

@@ -0,0 +1,231 @@
from functools import partial
from itertools import islice, cycle
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
from models.lucidrains.dalle.reversible import ReversibleSequence, SequentialSequence
from models.lucidrains.dalle.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, depth = 1):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else (val,) * depth
# classes
class DivideMax(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
maxes = x.amax(dim = self.dim, keepdim = True).detach()
return x / maxes
# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
def __init__(self, dim, depth, fn):
super().__init__()
if depth <= 18:
init_eps = 0.1
elif depth > 18 and depth <= 24:
init_eps = 1e-5
else:
init_eps = 1e-6
scale = torch.zeros(1, 1, dim).fill_(init_eps)
self.scale = nn.Parameter(scale)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
# layer norm
class PreNorm(nn.Module):
def __init__(self, dim, fn, sandwich = False):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
x = self.fn(x, **kwargs)
return self.norm_out(x)
# feed forward
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
class FeedForward(nn.Module):
def __init__(self, dim, dropout = 0., mult = 4.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
)
def forward(self, x):
return self.net(x)
# token shift classes
class PreShiftToken(nn.Module):
def __init__(self, fn, image_size, seq_len):
super().__init__()
self.fn = fn
self.image_size = image_size
self.seq_len = seq_len
def forward(self, x, **kwargs):
n = x.shape[1]
seq_len, image_size = self.seq_len, self.image_size
img_seq_len = image_size ** 2
text_len = seq_len - img_seq_len + 1
padding = seq_len - n + 1
# get text and image tokens
x_text, x_img = x[:, :text_len], x[:, text_len:]
x_img = F.pad(x_img, (0, 0, 0, padding))
x_img = rearrange(x_img, 'b (h w) d -> b h w d', h = image_size)
# shift 1 from the left for text tokens
x_text_shift, x_text_pass = x_text.chunk(2, dim = -1)
x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1))
x_text = torch.cat((x_text_shift, x_text_pass), dim = -1)
# shift from top, left for image tokens
x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim = -1)
x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1))
x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1))
x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim = -1)
# merge text and image sequence back together
x_img = rearrange(x_img, 'b h w d -> b (h w) d')
x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
return self.fn(x, **kwargs)
# main transformer class
class Transformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
seq_len,
reversible = False,
causal = True,
heads = 8,
dim_head = 64,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
attn_types = None,
image_fmap_size = None,
sparse_attn = False,
stable = False,
sandwich_norm = False,
shift_tokens = False,
rotary_emb = True
):
super().__init__()
layers = nn.ModuleList([])
sparse_layer = cast_tuple(sparse_attn, depth)
attn_types = default(attn_types, ('full',))
attn_types = cast_tuple(attn_types)
attn_type_layer = islice(cycle(attn_types), depth)
for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
if attn_type == 'full':
attn_class = partial(Attention, stable = stable)
elif attn_type == 'sparse':
attn_class = SparseAttention
elif attn_type == 'axial_row':
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
elif attn_type == 'axial_col':
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
elif attn_type == 'conv_like':
attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
elif attn_type == 'mlp':
attn_class = partial(gMLPBlock, seq_len = seq_len)
else:
raise ValueError(f'attention type "{attn_type}" is not valid')
if attn_type != 'mlp':
attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
else:
attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
if shift_tokens:
attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
layers.append(nn.ModuleList([
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
]))
execute_type = ReversibleSequence if reversible else SequentialSequence
route_attn = ((True, False),) * depth
attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}
self.layers = execute_type(layers, args_route = attn_route_map)
# generate positional embeddings for rotary
pos_emb = None
if rotary_emb:
assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on'
rot_dim = dim_head // 3
img_seq_len = (image_fmap_size ** 2)
text_len = seq_len - img_seq_len + 1
text_pos_emb = RotaryEmbedding(dim = rot_dim)
img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel')
text_freqs = text_pos_emb(torch.arange(text_len))
img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192)) # image is given a position far away from text
text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim = 0)
img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size))
img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), rearrange(img_freqs_axial, 'j d -> () j d')), dim = -1)
img_freqs = rearrange(img_freqs, 'h w d -> (h w) d')
text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.)) # text is given a position of -10, far away from the image axial positions, which lie in the range [-1, 1]
text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim = -1)
img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0)
pos_emb = torch.cat((text_freqs, img_freqs), dim = -1)
pos_emb = rearrange(pos_emb, 'n d -> () n d')
self.register_buffer('pos_emb', pos_emb)
def forward(self, x, **kwargs):
return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
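Two small checks of the new Transformer wrapper. `attn_types` is cycled over the depth, so a short tuple alternates across layers; and with `rotary_emb=False` plus the default full attention, the block can be smoke-tested without an `image_fmap_size`. The import path is assumed from this commit's layout:

```python
import torch
from itertools import cycle, islice
from models.lucidrains.dalle.transformer import Transformer

# How attn_types maps onto layers when the tuple is shorter than depth.
print(list(islice(cycle(('full', 'axial_row', 'axial_col')), 6)))
# ['full', 'axial_row', 'axial_col', 'full', 'axial_row', 'axial_col']

# Minimal instantiation: full attention only, no rotary embeddings.
model = Transformer(dim=256, depth=2, seq_len=64, heads=4, rotary_emb=False)
x = torch.randn(1, 64, 256)
print(model(x).shape)   # torch.Size([1, 64, 256])
```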

View File

@@ -11,18 +11,15 @@ munch
tqdm
scp
tensorboard
linear_attention_transformer
orjson
einops
lambda-networks
vector-quantize-pytorch
# For image generation stuff
opencv-python
kornia
pytorch_ssim
gsa-pytorch
vector_quantize_pytorch
pytorch_fid==0.1.1
# For audio generation stuff
@@ -31,4 +28,15 @@ librosa==0.6.0
Unidecode==1.0.22
tgt == 1.4.4
pyworld == 0.2.10
audio2numpy
audio2numpy
# For text stuff
transformers
tokenizers
# lucidrains stuff
vector_quantize_pytorch
linear_attention_transformer
rotary-embedding-torch
axial_positional_embedding
g-mlp-pytorch