diff --git a/codes/models/gpt_voice/gpt_asr.py b/codes/models/gpt_voice/gpt_asr.py deleted file mode 100644 index fa1df98f..00000000 --- a/codes/models/gpt_voice/gpt_asr.py +++ /dev/null @@ -1,242 +0,0 @@ -from time import time - -import torch -import torch.nn as nn -import torch.nn.functional as F -from munch import munchify - -from models.gpt_voice.lucidrains_gpt import Transformer -from models.tacotron2.taco_utils import get_mask_from_lengths -from models.tacotron2.text import symbols, sequence_to_text -from trainer.networks import register_model -from utils.util import opt_get - - -class ResBlock(nn.Module): - def __init__(self, chan): - super().__init__() - self.net = nn.Sequential( - nn.Conv1d(chan, chan, kernel_size=5, padding = 2), - nn.BatchNorm1d(chan), - nn.ReLU(), - nn.Conv1d(chan, chan, kernel_size=5, padding = 2), - nn.BatchNorm1d(chan) - ) - - def forward(self, x): - return F.relu(self.net(x) + x) - - -class MelEncoder(nn.Module): - def __init__(self, channels, mel_channels=80): - super().__init__() - self.channels = channels - self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3), - ResBlock(channels//4), - ResBlock(channels//4), - nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2), - nn.BatchNorm1d(channels//2), - nn.ReLU(), - ResBlock(channels//2), - ResBlock(channels//2), - ResBlock(channels//2), - nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2), - ResBlock(channels), - ResBlock(channels), - ResBlock(channels) - ) - - def forward(self, x): - return self.encoder(x) - - -class GptAsr(nn.Module): - NUMBER_SYMBOLS = len(symbols) - NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1 - - def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000): - super().__init__() - self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding. - self.max_symbols_per_phrase = max_symbols_per_phrase - - self.model_dim = model_dim - self.max_mel_frames = self.max_mel_frames - self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim) - self.mel_encoder = MelEncoder(model_dim) - self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) - self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim) - self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2 + self.max_symbols_per_phrase + self.max_mel_frames, heads=heads, - attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.max_mel_frames) - - self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) - - def get_logits(self, mel_inputs, text_targets): - # Pad front and back. Pad at front is the "START" token. 
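# A hypothetical, minimal illustration of the two F.pad calls that follow: the first
# prepends the START token id, the second right-pads with 0 (the PAD token) up to the
# fixed phrase length.
import torch
import torch.nn.functional as F

start_token, max_len = 148, 8                 # illustrative values only
t = torch.tensor([[5, 9, 2]])                 # (batch=1, seq=3) text token ids
t = F.pad(t, (1, 0), value=start_token)       # -> tensor([[148, 5, 9, 2]])
t = F.pad(t, (0, max_len - t.shape[1]))       # -> tensor([[148, 5, 9, 2, 0, 0, 0, 0]])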
- text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS) - text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1])) - text_emb = self.text_embedding(text_targets) - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device)) - mel_emb = self.mel_encoder(mel_inputs) - mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1])) - mel_emb = mel_emb.permute(0,2,1).contiguous() - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - emb = torch.cat([mel_emb, text_emb], dim=1) - enc = self.gpt(emb) - - text_logits = self.final_norm(enc[:, self.max_mel_frames:]) - text_logits = self.text_head(text_logits) - text_logits = text_logits.permute(0,2,1) - return text_logits - - def forward(self, mel_inputs, text_targets): - text_logits = self.get_logits(mel_inputs, text_targets) - loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long()) - return loss_text.mean(), text_logits - - def inference_beam_topk(self, mel, fn='inference_beam'): - def topk_sampler(distribution, k): - return torch.topk(distribution, k=k, dim=-1) - return getattr(self, fn)(mel, topk_sampler) - - def inference_beam_sampled(self, mel, fn='inference_beam'): - def multinomial_sampler(distribution, k): - indices = torch.multinomial(distribution, num_samples=k, replacement=False) - values = torch.gather(distribution, dim=1, index=indices) - class container: - def __init__(self, i, v): - self.indices = i - self.values = v - return container(indices, values) - return getattr(self, fn)(mel, multinomial_sampler) - - def inference_beam(self, mel_inputs, sampler_fn): - beam_width = 16 - temperature = .8 - - b, _, s = mel_inputs.shape - assert b == 1 # Beam search only works on batches of one. - mel_emb = self.mel_encoder(mel_inputs) - mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1])) - mel_emb = mel_emb.permute(0,2,1).contiguous() - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - - text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device) - probabilities = torch.ones((b,), device=mel_emb.device) - while text_seq.shape[-1] < self.max_symbols_per_phrase: - text_emb = self.text_embedding(text_seq) - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device)) - if text_emb.shape[0] != mel_emb.shape[0]: - mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1) - emb = torch.cat([mel_emb, text_emb], dim=1) - enc = self.gpt(emb) - text_logits = self.final_norm(enc[:, mel_emb.shape[1]:]) - text_logits = self.text_head(text_logits) - topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width) - probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten()) - probabilities, sort_indices = torch.sort(probabilities, descending=True) - probabilities = probabilities[:beam_width] - - text_seq = text_seq.repeat_interleave(beam_width, dim=0) - codes = topk.indices.flatten() - text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1) - text_seq = text_seq[sort_indices] - text_seq = text_seq[:beam_width] - - # PAD doubles as a stop token. PAD=0. - if torch.all(torch.any(text_seq == 0, dim=1)): - break - - if text_seq.shape[1] >= self.max_mel_frames: - print("Warning! Encountered frame limit before a pad token. 
Output is likely wrong.") - - return text_seq - - - def inference_beam_opt(self, mel_inputs, sampler_fn): - beam_width = 16 - temperature = .8 - - b, _, s = mel_inputs.shape - assert b == 1 # Beam search only works on batches of one. - mel_emb = self.mel_encoder(mel_inputs) - mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1])) - mel_emb = mel_emb.permute(0,2,1).contiguous() - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - - intermediates = [] - text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device) - probabilities = torch.ones((b,), device=mel_emb.device) - while text_seq.shape[-1] < self.max_symbols_per_phrase: - text_emb = self.text_embedding(text_seq) - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device)) - if text_emb.shape[0] != mel_emb.shape[0]: - mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1) - emb = torch.cat([mel_emb, text_emb], dim=1) - - if len(intermediates) == 0: - enc, intermediates = self.gpt(emb, return_intermediates=True) - intermediates = [(i[0].repeat(beam_width, 1, 1), - i[1].repeat(beam_width, 1, 1)) for i in intermediates] - else: - enc, intermediates = self.gpt.infer_last_two(emb, intermediates) - - text_logits = self.final_norm(enc[:, mel_emb.shape[1]:]) - text_logits = self.text_head(text_logits) - topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width) - probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten()) - probabilities, sort_indices = torch.sort(probabilities, descending=True) - probabilities = probabilities[:beam_width] - - text_seq = text_seq.repeat_interleave(beam_width, dim=0) - codes = topk.indices.flatten() - text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1) - text_seq = text_seq[sort_indices] - text_seq = text_seq[:beam_width] - - # PAD doubles as a stop token. PAD=0. - if torch.all(torch.any(text_seq == 0, dim=1)): - break - - if text_seq.shape[1] >= self.max_mel_frames: - print("Warning! Encountered frame limit before a pad token. Output is likely wrong.") - - return text_seq - - -@register_model -def register_gpt_asr(opt_net, opt): - return GptAsr(**opt_get(opt_net, ['kwargs'], {})) - - -# Quick script that loads a model and halves the number of layers, then saves that model. 
-def distill(): - gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12) - gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth')) - rc = 0 - i = 0 - while i < len(gpt.gpt.layers.layers): - if rc % 2 != 0: - del gpt.gpt.layers.layers[i] - else: - i += 1 - rc += 1 - torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth') - - -if __name__ == '__main__': - gpt = GptAsr(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2).cuda() - #l = gpt(torch.randn(2,80,800), - # torch.randint(high=len(symbols), size=(2,180))) - - with torch.no_grad(): - t = torch.randn(1,80,800).cuda() - start = time() - s = gpt.inference_beam_topk(t) - print(time()-start) - - start = time() - o = gpt.inference_beam_topk(t, fn='inference_beam_opt') - print(time()-start) - - diff --git a/codes/models/gpt_voice/gpt_audio_segmentor.py b/codes/models/gpt_voice/gpt_audio_segmentor.py deleted file mode 100644 index 8664bea0..00000000 --- a/codes/models/gpt_voice/gpt_audio_segmentor.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from munch import munchify - -from models.gpt_voice.lucidrains_gpt import Transformer -from models.tacotron2.taco_utils import get_mask_from_lengths -from models.tacotron2.text import symbols, sequence_to_text -from trainer.networks import register_model -from utils.util import opt_get - - -class ResBlock(nn.Module): - def __init__(self, chan): - super().__init__() - self.net = nn.Sequential( - nn.Conv1d(chan, chan, kernel_size=5, padding = 2), - nn.BatchNorm1d(chan), - nn.ReLU(), - nn.Conv1d(chan, chan, kernel_size=5, padding = 2), - nn.BatchNorm1d(chan) - ) - - def forward(self, x): - return F.relu(self.net(x) + x) - - -class MelEncoder(nn.Module): - def __init__(self, channels, mel_channels=80): - super().__init__() - self.channels = channels - self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3), - ResBlock(channels//4), - ResBlock(channels//4), - nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2), - nn.BatchNorm1d(channels//2), - nn.ReLU(), - ResBlock(channels//2), - ResBlock(channels//2), - ResBlock(channels//2), - nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2), - ResBlock(channels), - ResBlock(channels), - ResBlock(channels) - ) - - def forward(self, x): - return self.encoder(x) - - -class GptSegmentor(nn.Module): - MAX_MEL_FRAMES = 2000 // 4 - - def __init__(self, layers=8, model_dim=512, heads=8): - super().__init__() - - self.model_dim = model_dim - self.max_mel_frames = self.MAX_MEL_FRAMES - self.mel_encoder = MelEncoder(model_dim) - self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim) - self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=self.MAX_MEL_FRAMES, heads=heads, - attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES) - - self.final_norm = nn.LayerNorm(model_dim) - self.start_head = nn.Linear(model_dim, 1) - self.stop_head = nn.Linear(model_dim, 1) - - def forward(self, mel_inputs, start_labels=None, end_labels=None): - mel_emb = self.mel_encoder(mel_inputs) - mel_emb = mel_emb.permute(0,2,1).contiguous() - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - - enc = self.gpt(mel_emb) - logits = self.final_norm(enc) - stop_logits = self.stop_head(logits) - start_logits = self.start_head(logits) - 
- if start_labels is not None: - # Compute loss - start_loss = F.binary_cross_entropy_with_logits(start_logits.squeeze(-1), start_labels.float()) - end_loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), end_labels.float()) - return start_loss.mean(), end_loss.mean() - else: - return start_logits, stop_logits - - - -@register_model -def register_gpt_segmentor(opt_net, opt): - return GptSegmentor(**opt_get(opt_net, ['kwargs'], {})) - - -if __name__ == '__main__': - gpt = GptSegmentor() - l = gpt(torch.randn(3,80,94), - torch.zeros(3,94)) - print(l.shape) - - #o = gpt.infer(torch.randint(high=24, size=(2,60))) - #print(o.shape) - - diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py deleted file mode 100644 index 43066dfa..00000000 --- a/codes/models/gpt_voice/gpt_tts.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from munch import munchify - -from models.gpt_voice.lucidrains_gpt import Transformer -from models.tacotron2.taco_utils import get_mask_from_lengths -from models.tacotron2.text import symbols -from trainer.networks import register_model -from utils.util import opt_get - - -class GptTts(nn.Module): - MAX_SYMBOLS_PER_PHRASE = 200 - NUMBER_SYMBOLS = len(symbols) - NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2 - MEL_DICTIONARY_SIZE = 512+3 - MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3 - MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2 - - def __init__(self, layers=8, model_dim=512, heads=8): - super().__init__() - max_mel_frames = 900 * 1 // 4 # 900 is the max number of MEL frames. The VQVAE outputs 1/8 of the input mel as tokens. - - self.model_dim = model_dim - self.max_mel_frames = max_mel_frames - self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim) - self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim) - self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim) - self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim) - self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=heads, - attn_dropout=.1, ff_dropout=.1) - - self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) - self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE) - - def forward(self, text_inputs, text_lengths, mel_targets, output_lengths): - text_emb = self.text_embedding(text_inputs) - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - mel_emb = self.mel_embedding(mel_targets) - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device)) - emb = torch.cat([text_emb, mel_emb], dim=1) - enc = self.gpt(emb) - - # Compute logits for text and mel heads - text_logits = self.final_norm(enc[:, :text_emb.shape[1]]) - mel_logits = self.final_norm(enc[:, text_emb.shape[1]:]) - text_logits = self.text_head(text_logits) - mel_logits = self.mel_head(mel_logits) - - # Compute loss - text_targets = text_inputs[:,1:] - text_logits = text_logits.permute(0,2,1)[:,:,:-1] # The last element of the logits is unneeded because the input to the transformer contains a token for both text and mel. 
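# A minimal sketch of the next-token shift used for the losses below: logits at position t
# are scored against the token at position t+1, so the last logit column and the first
# target token are dropped (hypothetical shapes).
import torch
import torch.nn.functional as F

logits = torch.randn(2, 10, 7)                # (batch, vocab, seq), as after permute(0, 2, 1)
tokens = torch.randint(0, 10, (2, 7))         # (batch, seq) token ids
loss = F.cross_entropy(logits[:, :, :-1], tokens[:, 1:], reduction='none')   # (batch, seq-1)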
- loss_text = F.cross_entropy(text_logits, text_targets, reduction='none') - mel_targets = mel_targets[:,1:] - mel_logits = mel_logits.permute(0,2,1)[:,:,:-1] - loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none') - - # Fix up mel_logits so it can go into a VAE decoder as well. - mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1) - mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1]) - mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0) - mel_codes = mel_codes[:,:-1] # Strip off token too (or padding). The important part is that the output sequence length is identical to the VAE input. - extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD - mel_codes = mel_codes * extra_mask - - # This class also returns the mel_targets for validation purposes. Format those. - mel_targets = mel_targets[:,:-1] - mel_targets = mel_targets * (mel_targets < self.MEL_DICTIONARY_SIZE-3) - return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets - - def inference(self, text_inputs): - b, s = text_inputs.shape - text_emb = self.text_embedding(text_inputs) - text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device)) - - mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device) - stop_encountered = torch.zeros((b,), device=text_emb.device) - while not torch.all(stop_encountered) and len(mel_seq) < self.max_mel_frames: - mel_emb = self.mel_embedding(mel_seq) - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - emb = torch.cat([text_emb, mel_emb], dim=1) - enc = self.gpt(emb) - mel_logits = self.final_norm(enc[:, text_emb.shape[1]:]) - mel_logits = self.mel_head(mel_logits) - mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1) - mel_seq = torch.cat([mel_seq, mel_codes[:, -1].unsqueeze(1)], dim=1) - stop_encountered = torch.logical_or(stop_encountered, mel_seq[:,-1] == self.MEL_STOP_TOKEN) - - if len(mel_seq) >= self.max_mel_frames: - print("Warning! Encountered frame limit before a stop token. Output is likely wrong.") - - # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE) - mel_seq = mel_seq[:, 1:-1] # Remove first and last tokens, which were artificially added for GPT - mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens. - - return mel_seq - - def inference_beam_topk(self, text): - def topk_sampler(distribution, k): - return torch.topk(distribution, k=k, dim=-1) - return self.inference_beam(text, topk_sampler) - - def inference_beam_sampled(self, text): - def multinomial_sampler(distribution, k): - indices = torch.multinomial(distribution, num_samples=k, replacement=False) - values = torch.gather(distribution, dim=1, index=indices) - class container: - def __init__(self, i, v): - self.indices = i - self.values = v - return container(indices, values) - return self.inference_beam(text, multinomial_sampler) - - def inference_beam(self, text_inputs, sampler_fn): - beam_width = 16 - temperature = .8 - - b, s = text_inputs.shape - assert b == 1 # Beam search only works on batches of one. 
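# A minimal sketch of the beam bookkeeping used in the loop below: every live hypothesis is
# expanded with its top-k candidates, cumulative probabilities are re-sorted, and only the
# best `beam_width` sequences survive to the next step (hypothetical sizes).
import torch
import torch.nn.functional as F

beam_width, vocab = 3, 10
seqs = torch.zeros(1, 1, dtype=torch.long)            # running sequences, (beams, length)
probs = torch.ones(1)                                 # cumulative probability per beam
step_logits = torch.randn(seqs.shape[0], vocab)       # hypothetical next-token logits per beam
topk = torch.topk(F.softmax(step_logits, dim=-1), k=beam_width, dim=-1)
probs = probs.repeat_interleave(beam_width, dim=0) * topk.values.flatten()
probs, order = torch.sort(probs, descending=True)
seqs = seqs.repeat_interleave(beam_width, dim=0)
seqs = torch.cat([seqs, topk.indices.flatten().unsqueeze(1)], dim=1)[order][:beam_width]
probs = probs[:beam_width]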
- text_emb = self.text_embedding(text_inputs) - text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device)) - mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device) - probabilities = torch.ones((b,), device=text_emb.device) - while len(mel_seq) < self.max_mel_frames: - mel_emb = self.mel_embedding(mel_seq) - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - if text_emb.shape[0] != mel_emb.shape[0]: - text_emb = text_emb.repeat(mel_emb.shape[0], 1, 1) - emb = torch.cat([text_emb, mel_emb], dim=1) - enc = self.gpt(emb) - mel_logits = self.final_norm(enc[:, text_emb.shape[1]:]) - mel_logits = self.mel_head(mel_logits) - topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width) - probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten()) - probabilities, sort_indices = torch.sort(probabilities, descending=True) - probabilities = probabilities[:beam_width] - - mel_seq = mel_seq.repeat_interleave(beam_width, dim=0) - codes = topk.indices.flatten() - mel_seq = torch.cat([mel_seq, codes.unsqueeze(1)], dim=1) - mel_seq = mel_seq[sort_indices] - mel_seq = mel_seq[:beam_width] - - if torch.all(torch.any(mel_seq == self.MEL_STOP_TOKEN, dim=1)): - break - - if mel_seq.shape[1] >= self.max_mel_frames: - print("Warning! Encountered frame limit before a stop token. Output is likely wrong.") - - # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE) - mel_seq = mel_seq[0, 1:-1].unsqueeze(0) # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT - mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens. - - return mel_seq - - - -@register_model -def register_gpt_tts(opt_net, opt): - return GptTts(**opt_get(opt_net, ['kwargs'], {})) - - -if __name__ == '__main__': - gpt = GptTts() - l1, l2, i = gpt(torch.randint(high=24, size=(2,60)), - torch.tensor([55,58]), - torch.randint(high=512, size=(2,310)), - torch.tensor([300,305])) - print(i.shape) - - #o = gpt.infer(torch.randint(high=24, size=(2,60))) - #print(o.shape) - - diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py deleted file mode 100644 index b218f7fc..00000000 --- a/codes/models/gpt_voice/lucidrains_gpt.py +++ /dev/null @@ -1,236 +0,0 @@ -from inspect import isfunction - -import torch -from torch import nn, einsum -import torch.nn.functional as F -from einops import rearrange - -# helpers -from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence -from utils.util import checkpoint - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -def cast_tuple(val, depth = 1): - if isinstance(val, list): - val = tuple(val) - return val if isinstance(val, tuple) else (val,) * depth - - -class DivideMax(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - maxes = x.amax(dim = self.dim, keepdim = True) - return x / maxes - - -# https://arxiv.org/abs/2103.17239 -class LayerScale(nn.Module): - def __init__(self, dim, depth, fn): - super().__init__() - if depth <= 18: - init_eps = 0.1 - elif depth > 18 and depth <= 24: - init_eps = 1e-5 - else: - init_eps = 1e-6 - - scale = torch.zeros(1, 1, dim).fill_(init_eps) - self.scale = nn.Parameter(scale) - self.fn = fn - - def forward(self, x, **kwargs): - return self.fn(x, **kwargs) * 
self.scale - - -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.norm = nn.LayerNorm(dim) - self.fn = fn - - def forward(self, x, **kwargs): - return self.fn(self.norm(x), **kwargs) - - -class GEGLU(nn.Module): - def forward(self, x): - x, gates = x.chunk(2, dim = -1) - return x * F.gelu(gates) - - -class FeedForward(nn.Module): - def __init__(self, dim, dropout = 0., mult = 4.): - super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, dim * mult * 2), - GEGLU(), - nn.Dropout(dropout), - nn.Linear(dim * mult, dim) - ) - - def forward(self, x, only_last_two_elements=False): - if only_last_two_elements: - h = x[:, -2:] - h = self.net(h) - return torch.cat([x[:, :-2], h], dim=1) - else: - return self.net(x) - - -def exists(val): - return val is not None - - -def uniq(arr): - return{el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def stable_softmax(t, dim = -1, alpha = 32 ** 2): - t = t / alpha - t = t - torch.amax(t, dim = dim, keepdim = True) - return (t * alpha).softmax(dim = dim) - - -# classes -class Attention(nn.Module): - def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False): - super().__init__() - inner_dim = dim_head * heads - self.heads = heads - self.seq_len = seq_len - self.scale = dim_head ** -0.5 - - self.stable = stable - self.non_causal_sequence_partition = non_causal_sequence_partition - - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x, mask = None, only_last_two_elements=False): - b, n, _, h, device = *x.shape, self.heads, x.device - softmax = torch.softmax if not self.stable else stable_softmax - - # TODO: Q and V do not need to be recomputed for existing elements in intermediate_latents is specified. V would need to be cached though. 
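# A small illustration of the `only_last_two_elements` path handled below: queries are sliced
# to the newest positions while keys/values still span the whole sequence, so only the rows of
# attention output that can actually change get recomputed (causal mask omitted for brevity,
# hypothetical shapes).
import torch

b, h, n, d = 1, 2, 6, 8
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
dots = torch.einsum('b h i d, b h j d -> b h i j', q[:, :, -2:], k)                   # (b, h, 2, n)
out_last_two = torch.einsum('b h i j, b h j d -> b h i d', dots.softmax(dim=-1), v)   # (b, h, 2, d)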
- qkv = self.to_qkv(x).chunk(3, dim = -1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) - q = q * self.scale - - if only_last_two_elements: - q = q[:, :, -2:] - assert not exists(mask) # Don't know how to resolve this (currently) - - dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) - mask_value = max_neg_value(dots) - - if exists(mask): - mask = rearrange(mask, 'b j -> b () () j') - dots.masked_fill_(~mask, mask_value) - del mask - - i, j = dots.shape[-2:] - mask = torch.ones(i, j, device = device).triu_(j - i + 1) - if self.non_causal_sequence_partition > 0: - non_causal_mask = torch.ones((i, j), device=device) - non_causal_mask[:, :self.non_causal_sequence_partition] = 0 - mask = mask * non_causal_mask - - dots.masked_fill_(mask.bool(), mask_value) - - attn = softmax(dots, dim=-1) - - out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - out = self.to_out(out) - return out - - -class Transformer(nn.Module): - def __init__( - self, - *, - dim, - depth, - seq_len, - reversible = False, - heads = 8, - dim_head = 64, - ff_mult = 4, - attn_dropout = 0., - ff_dropout = 0., - sparse_attn = False, - stable = False, - non_causal_sequence_partition=0, - ): - super().__init__() - layers = nn.ModuleList([]) - sparse_layer = cast_tuple(sparse_attn, depth) - - for ind, sparse_attn in zip(range(depth), sparse_layer): - attn = Attention(dim, stable=stable, non_causal_sequence_partition = non_causal_sequence_partition, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout) - - layers.append(nn.ModuleList([ - LayerScale(dim, ind + 1, PreNorm(dim, attn)), - LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))) - ])) - - # TODO: Remove this nonsense. I don't want to support reversible sequences and this is just a mess. - execute_type = ReversibleSequence if reversible else SequentialSequence - route_attn = ((True, False),) * depth - attn_route_map = {'mask': route_attn} - - self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True) - self.depth = depth - - def forward(self, x, return_intermediates=False): - intermediates = [] - for attn, ff in self.layers.layers: - x_ff = x + checkpoint(attn, x) - x = x_ff + ff(x_ff) - if return_intermediates: - intermediates.append((x_ff, x)) - if return_intermediates: - return x, intermediates - else: - return x - - def infer_last_two(self, x, prev_intermediates): - """ - Performs an forward pass only on the last two element in the given sequence (allowing it to attend to all other - elements). This is useful for faster autoregressive decoding. - - The last two elements are important because in inference, the last element is the prediction candidate and the - second-to-last element is a newly selected element from the autoregressive searching process. - """ - assert(len(prev_intermediates) == self.depth) - new_intermediates = [] - for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates): - x_ff = attn(x, only_last_two_elements=True) - # Note that (x) is now only the last two element in the set. Conjoin it with the int_ff latent to compute the norm. 
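# A shape-level sketch of the splice below: the attention residual cached from the previous
# decoding step covers every position except the newest one, so its final entry is dropped and
# replaced with the freshly computed two-position slice to restore the full sequence length
# (hypothetical sizes).
import torch

n, d = 6, 4
x = torch.randn(1, n, d)                  # current full input sequence
cached_ff = torch.randn(1, n - 1, d)      # residual cached when the sequence was one token shorter
fresh_last_two = torch.randn(1, 2, d)     # attention output recomputed for the last two positions
x_ff = x + torch.cat([cached_ff[:, :-1], fresh_last_two], dim=1)    # (1, n, d)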
- x_ff = x + torch.cat([int_ff[:,:-1], x_ff], dim=1) - x = x_ff + ff(x_ff, only_last_two_elements=True) - new_intermediates.append((x_ff, x)) - return x, new_intermediates diff --git a/codes/models/gpt_voice/voice_clip.py b/codes/models/gpt_voice/voice_clip.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/lucidrains/dalle/__init__.py b/codes/models/lucidrains/dalle/__init__.py new file mode 100644 index 00000000..d8f37633 --- /dev/null +++ b/codes/models/lucidrains/dalle/__init__.py @@ -0,0 +1 @@ +# This directory contains some useful code from https://github.com/lucidrains/DALLE-pytorch/tree/main/dalle_pytorch \ No newline at end of file diff --git a/codes/models/lucidrains/dalle/attention.py b/codes/models/lucidrains/dalle/attention.py new file mode 100644 index 00000000..c3b52cc4 --- /dev/null +++ b/codes/models/lucidrains/dalle/attention.py @@ -0,0 +1,384 @@ +from inspect import isfunction +from math import ceil + +import torch +from torch import nn, einsum +import torch.nn.functional as F +from einops import rearrange, repeat + +from rotary_embedding_torch import apply_rotary_emb + +# helpers + +def exists(val): + return val is not None + +def uniq(arr): + return{el: True for el in arr}.keys() + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + +def stable_softmax(t, dim = -1, alpha = 32 ** 2): + t = t / alpha + t = t - torch.amax(t, dim = dim, keepdim = True).detach() + return (t * alpha).softmax(dim = dim) + +def apply_pos_emb(pos_emb, qkv): + n = qkv[0].shape[-2] + pos_emb = pos_emb[..., :n, :] + return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv)) + +# classes + +class Attention(nn.Module): + def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.seq_len = seq_len + self.scale = dim_head ** -0.5 + + self.stable = stable + self.causal = causal + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None, rotary_pos_emb = None): + b, n, _, h, device = *x.shape, self.heads, x.device + softmax = torch.softmax if not self.stable else stable_softmax + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + if exists(rotary_pos_emb): + q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) + + q = q * self.scale + + dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) + mask_value = max_neg_value(dots) + + if exists(mask): + mask = rearrange(mask, 'b j -> b () () j') + dots.masked_fill_(~mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + dots.masked_fill_(mask, mask_value) + + attn = softmax(dots, dim=-1) + + out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + +# sparse attention with convolutional pattern, as mentioned in the blog post. 
customizable kernel size and dilation + +class SparseConvCausalAttention(nn.Module): + def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs): + super().__init__() + assert kernel_size % 2 == 1, 'kernel size must be odd' + + inner_dim = dim_head * heads + self.seq_len = seq_len + self.heads = heads + self.scale = dim_head ** -0.5 + self.image_size = image_size + self.kernel_size = kernel_size + self.dilation = dilation + + self.stable = stable + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None, rotary_pos_emb = None): + b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device + softmax = torch.softmax if not self.stable else stable_softmax + + img_seq_len = img_size ** 2 + text_len = seq_len + 1 - img_seq_len + + # padding + + padding = seq_len - n + 1 + mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool()) + + x = F.pad(x, (0, 0, 0, padding), value = 0) + mask = mask[:, :text_len] + + # derive query / keys / values + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv) + + if exists(rotary_pos_emb): + q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) + + q *= self.scale + + ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)) + + # text attention + + dots_text = einsum('b i d, b j d -> b i j', q_text, k_text) + mask_value = max_neg_value(dots_text) + + i, j = dots_text.shape[-2:] + text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + dots_text.masked_fill_(text_causal_mask, mask_value) + + attn_text = softmax(dots_text, dim = -1) + out_text = einsum('b i j, b j d -> b i d', attn_text, v_text) + + # image attention + + effective_kernel_size = (kernel_size - 1) * dilation + 1 + padding = effective_kernel_size // 2 + + k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img)) + k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img)) + k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img)) + + # let image attend to all of text + + dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img) + dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text) + + # calculate causal attention for local convolution + + i, j = dots_image.shape[-2:] + img_seq = torch.arange(img_seq_len, device = device) + k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size) + k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len) # padding set to be max, so it is never attended to + k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation) + k_img_indices = rearrange(k_img_indices, 'b j i -> b i j') + + # mask image attention + + q_img_indices = rearrange(img_seq, 'i -> () i ()') + causal_mask = q_img_indices < k_img_indices + + # concat text mask with image causal mask + + causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h) + mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h) + mask = torch.cat((~mask, causal_mask), dim = -1) + + # image can attend to all of text + + dots = 
torch.cat((dots_image_to_text, dots_image), dim = -1) + dots.masked_fill_(mask, mask_value) + + attn = softmax(dots, dim = -1) + + # aggregate + + attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:] + + out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img) + out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text) + + out_image = out_image_to_image + out_image_to_text + + # combine attended values for both text and image + + out = torch.cat((out_text, out_image), dim = 1) + + out = rearrange(out, '(b h) n d -> b n (h d)', h = h) + out = self.to_out(out) + return out[:, :n] + +# sparse axial causal attention + +class SparseAxialCausalAttention(nn.Module): + def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs): + super().__init__() + assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)' + self.axis = axis + + inner_dim = dim_head * heads + self.seq_len = seq_len + self.heads = heads + self.scale = dim_head ** -0.5 + self.image_size = image_size + + self.stable = stable + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None, rotary_pos_emb = None): + b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device + softmax = torch.softmax if not self.stable else stable_softmax + + img_seq_len = img_size ** 2 + text_len = seq_len + 1 - img_seq_len + + # padding + + padding = seq_len - n + 1 + mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool()) + + x = F.pad(x, (0, 0, 0, padding), value = 0) + mask = mask[:, :text_len] + + # derive queries / keys / values + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv) + + if exists(rotary_pos_emb): + q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) + + q *= self.scale + + ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)) + + # text attention + + dots_text = einsum('b i d, b j d -> b i j', q_text, k_text) + mask_value = max_neg_value(dots_text) + + i, j = dots_text.shape[-2:] + text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + dots_text.masked_fill_(text_causal_mask, mask_value) + + attn_text = softmax(dots_text, dim = -1) + out_text = einsum('b i j, b j d -> b i d', attn_text, v_text) + + # image attention + + split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c' + merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d' + + # split out axis + + q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img)) + + # similarity + + dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img) + dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text) + + dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1) + + # mask so image has full attention to text, but causal along axis + + bh, x, i, j = dots.shape + causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool() + causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x) + + mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i) + mask = torch.cat((~mask, causal_mask), dim = -1) + 
+ dots.masked_fill_(mask, mask_value) + + # attention. + + attn = softmax(dots, dim = -1) + + # aggregate + + attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:] + + out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img) + out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text) + + out_image = out_image_to_image + out_image_to_text + + # merge back axis + + out_image = rearrange(out_image, merge_axis_einops, x = img_size) + + # combine attended values for both text and image + + out = torch.cat((out_text, out_image), dim = 1) + + out = rearrange(out, '(b h) n d -> b n (h d)', h = h) + out = self.to_out(out) + return out[:, :n] + +# microsoft sparse attention CUDA kernel + +class SparseAttention(Attention): + def __init__( + self, + *args, + block_size = 16, + text_seq_len = 256, + num_random_blocks = None, + **kwargs + ): + super().__init__(*args, **kwargs) + from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig + self.block_size = block_size + + num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4) + global_block_indices = list(range(ceil(text_seq_len / block_size))) + + self.attn_fn = SparseSelfAttention( + sparsity_config = VariableSparsityConfig( + num_heads = self.heads, + block = self.block_size, + num_random_blocks = num_random_blocks, + global_block_indices = global_block_indices, + attention = 'unidirectional' if self.causal else 'bidirectional' + ), + max_seq_length = self.seq_len, + attn_mask_mode = 'add' + ) + + def forward(self, x, mask = None, rotary_pos_emb = None): + b, n, _, h, device = *x.shape, self.heads, x.device + remainder = n % self.block_size + mask = default(mask, lambda: torch.ones(b, n, device = device).bool()) + + if remainder > 0: + padding = self.block_size - remainder + x = F.pad(x, (0, 0, 0, padding), value = 0) + mask = F.pad(mask, (0, padding), value = False) + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + if exists(rotary_pos_emb): + q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) + + key_pad_mask = None + if exists(mask): + key_pad_mask = ~mask + + attn_mask = None + if self.causal: + i, j = q.shape[-2], k.shape[-2] + mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + attn_mask = torch.zeros(i, j, device = device).to(q) + mask_value = max_neg_value(q) / 2 + attn_mask.masked_fill_(mask, mask_value) + + out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out[:, :n] \ No newline at end of file diff --git a/codes/models/gpt_voice/reversible.py b/codes/models/lucidrains/dalle/reversible.py similarity index 94% rename from codes/models/gpt_voice/reversible.py rename to codes/models/lucidrains/dalle/reversible.py index 481a2927..a235323a 100644 --- a/codes/models/gpt_voice/reversible.py +++ b/codes/models/lucidrains/dalle/reversible.py @@ -1,12 +1,10 @@ import torch import torch.nn as nn +from operator import itemgetter from torch.autograd.function import Function from torch.utils.checkpoint import get_device_states, set_device_states # for routing arguments into the functions of the reversible layer -from utils.util import checkpoint - - def route_args(router, args, depth): routed_args = [(dict(), dict()) for _ in range(depth)] matched_keys = [key for key in args.keys() if key in router] @@ -126,25 +124,20 
@@ class _ReversibleFunction(Function): return dy, None, None class SequentialSequence(nn.Module): - def __init__(self, layers, args_route = {}, layer_dropout = 0., checkpoint=False): + def __init__(self, layers, args_route = {}, layer_dropout = 0.): super().__init__() assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' self.layers = layers self.args_route = args_route self.layer_dropout = layer_dropout - self.checkpoint = checkpoint def forward(self, x, **kwargs): args = route_args(self.args_route, kwargs, len(self.layers)) layers_and_args = list(zip(self.layers, args)) for (f, g), (f_args, g_args) in layers_and_args: - if self.checkpoint: - x = x + f(x, **f_args) - x = x + g(x, **g_args) - else: - x = x + checkpoint(f, x, **f_args) - x = x + checkpoint(g, x, **g_args) + x = x + f(x, **f_args) + x = x + g(x, **g_args) return x class ReversibleSequence(nn.Module): diff --git a/codes/models/lucidrains/dalle/transformer.py b/codes/models/lucidrains/dalle/transformer.py new file mode 100644 index 00000000..27f5f2bd --- /dev/null +++ b/codes/models/lucidrains/dalle/transformer.py @@ -0,0 +1,231 @@ +from functools import partial +from itertools import islice, cycle + +import torch +from torch import nn, einsum +import torch.nn.functional as F +from einops import rearrange + +from models.lucidrains.dalle.reversible import ReversibleSequence, SequentialSequence +from models.lucidrains.dalle.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention + +from rotary_embedding_torch import RotaryEmbedding, broadcat +from g_mlp_pytorch import gMLPBlock + +# helpers + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def cast_tuple(val, depth = 1): + if isinstance(val, list): + val = tuple(val) + return val if isinstance(val, tuple) else (val,) * depth + +# classes + +class DivideMax(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + maxes = x.amax(dim = self.dim, keepdim = True).detach() + return x / maxes + +# https://arxiv.org/abs/2103.17239 +class LayerScale(nn.Module): + def __init__(self, dim, depth, fn): + super().__init__() + if depth <= 18: + init_eps = 0.1 + elif depth > 18 and depth <= 24: + init_eps = 1e-5 + else: + init_eps = 1e-6 + + scale = torch.zeros(1, 1, dim).fill_(init_eps) + self.scale = nn.Parameter(scale) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) * self.scale + +# layer norm + +class PreNorm(nn.Module): + def __init__(self, dim, fn, sandwich = False): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() + self.fn = fn + + def forward(self, x, **kwargs): + x = self.norm(x) + x = self.fn(x, **kwargs) + return self.norm_out(x) + +# feed forward + +class GEGLU(nn.Module): + def forward(self, x): + x, gates = x.chunk(2, dim = -1) + return x * F.gelu(gates) + +class FeedForward(nn.Module): + def __init__(self, dim, dropout = 0., mult = 4.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim * mult * 2), + GEGLU(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim) + ) + + def forward(self, x): + return self.net(x) + +# token shift classes + +class PreShiftToken(nn.Module): + def __init__(self, fn, image_size, seq_len): + super().__init__() + self.fn = fn + self.image_size = image_size + self.seq_len = seq_len + + 
def forward(self, x, **kwargs): + n = x.shape[1] + seq_len, image_size = self.seq_len, self.image_size + img_seq_len = image_size ** 2 + text_len = seq_len - img_seq_len + 1 + padding = seq_len - n + 1 + + # get text and image tokens + + x_text, x_img = x[:, :text_len], x[:, text_len:] + x_img = F.pad(x_img, (0, 0, 0, padding)) + x_img = rearrange(x_img, 'b (h w) d -> b h w d', h = image_size) + + # shift 1 from the left for text tokens + + x_text_shift, x_text_pass = x_text.chunk(2, dim = -1) + x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1)) + x_text = torch.cat((x_text_shift, x_text_pass), dim = -1) + + # shift from top, left for image tokens + + x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim = -1) + x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1)) + x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1)) + x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim = -1) + + # merge text and image sequence back together + + x_img = rearrange(x_img, 'b h w d -> b (h w) d') + x = torch.cat((x_text, x_img[:, :-padding]), dim = 1) + return self.fn(x, **kwargs) + +# main transformer class + +class Transformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + seq_len, + reversible = False, + causal = True, + heads = 8, + dim_head = 64, + ff_mult = 4, + attn_dropout = 0., + ff_dropout = 0., + attn_types = None, + image_fmap_size = None, + sparse_attn = False, + stable = False, + sandwich_norm = False, + shift_tokens = False, + rotary_emb = True + ): + super().__init__() + layers = nn.ModuleList([]) + sparse_layer = cast_tuple(sparse_attn, depth) + + attn_types = default(attn_types, ('full',)) + attn_types = cast_tuple(attn_types) + attn_type_layer = islice(cycle(attn_types), depth) + + for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer): + if attn_type == 'full': + attn_class = partial(Attention, stable = stable) + elif attn_type == 'sparse': + attn_class = SparseAttention + elif attn_type == 'axial_row': + attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable) + elif attn_type == 'axial_col': + attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable) + elif attn_type == 'conv_like': + attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable) + elif attn_type == 'mlp': + attn_class = partial(gMLPBlock, seq_len = seq_len) + else: + raise ValueError(f'attention type "{attn_type}" is not valid') + + if attn_type != 'mlp': + attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout) + else: + attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4) + + ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout) + + if shift_tokens: + attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff)) + + layers.append(nn.ModuleList([ + LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)), + LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm)) + ])) + + execute_type = ReversibleSequence if reversible else SequentialSequence + route_attn = ((True, False),) * depth + attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn} + + self.layers = execute_type(layers, args_route = attn_route_map) + + # generate positional embeddings for rotary + + pos_emb = None + if rotary_emb: + 
assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on' + + rot_dim = dim_head // 3 + img_seq_len = (image_fmap_size ** 2) + text_len = seq_len - img_seq_len + 1 + + text_pos_emb = RotaryEmbedding(dim = rot_dim) + img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel') + + text_freqs = text_pos_emb(torch.arange(text_len)) + img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192)) # image is given a position far away from text + text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim = 0) + + img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size)) + img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), rearrange(img_freqs_axial, 'j d -> () j d')), dim = -1) + img_freqs = rearrange(img_freqs, 'h w d -> (h w) d') + + text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.)) # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1] + text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim = -1) + img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0) + + pos_emb = torch.cat((text_freqs, img_freqs), dim = -1) + pos_emb = rearrange(pos_emb, 'n d -> () n d') + + self.register_buffer('pos_emb', pos_emb) + + def forward(self, x, **kwargs): + return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs) \ No newline at end of file diff --git a/codes/requirements.txt b/codes/requirements.txt index 1b1a6951..73355204 100644 --- a/codes/requirements.txt +++ b/codes/requirements.txt @@ -11,18 +11,15 @@ munch tqdm scp tensorboard -linear_attention_transformer orjson einops lambda-networks -vector-quantize-pytorch # For image generation stuff opencv-python kornia pytorch_ssim gsa-pytorch -vector_quantize_pytorch pytorch_fid==0.1.1 # For audio generation stuff @@ -31,4 +28,15 @@ librosa==0.6.0 Unidecode==1.0.22 tgt == 1.4.4 pyworld == 0.2.10 -audio2numpy \ No newline at end of file +audio2numpy + +# For text stuff +transformers +tokenizers + +# lucidrains stuff +vector_quantize_pytorch +linear_attention_transformer +rotary-embedding-torch +axial_positional_embedding +g-mlp-pytorch \ No newline at end of file
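With the newly listed rotary-embedding-torch and related lucidrains packages installed, the vendored DALLE-pytorch transformer added above can be exercised with a minimal sketch along the following lines; all sizes here are hypothetical, and the import path assumes codes/ is the package root, as elsewhere in this repository.

import torch
from models.lucidrains.dalle.transformer import Transformer

text_len, fmap = 256, 32                       # hypothetical text length / image feature-map size
model = Transformer(
    dim=512,
    depth=2,
    heads=8,
    seq_len=text_len + fmap ** 2,              # combined text + image token count
    attn_types=('full',),                      # could also mix 'axial_row', 'axial_col', 'conv_like'
    image_fmap_size=fmap,                      # required to build the rotary position embeddings
)
x = torch.randn(1, text_len + fmap ** 2, 512)  # already-embedded token sequence
out = model(x)                                 # -> (1, text_len + fmap ** 2, 512)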