forked from mrq/DL-Art-School

Remove obsolete lucidrains DALLE stuff, re-create it in a dedicated folder

This commit is contained in:
parent a42b94ab72
commit 09f7f3e615
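For orientation, a minimal sketch of the import-path change this commit implies (module paths are taken from the removed and added files below; the call sites themselves are an assumption, since no callers appear in this diff):

    # Hypothetical usage sketch, assuming the relocated modules added in this commit.
    # Before: these helpers lived under models.gpt_voice.lucidrains_gpt (removed below).
    # After: they are re-created under codes/models/lucidrains/dalle/.
    from models.lucidrains.dalle.attention import Attention
    from models.lucidrains.dalle.transformer import Transformer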
@@ -1,242 +0,0 @@
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify

from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols, sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get


class ResBlock(nn.Module):
    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(chan, chan, kernel_size=5, padding=2),
            nn.BatchNorm1d(chan),
            nn.ReLU(),
            nn.Conv1d(chan, chan, kernel_size=5, padding=2),
            nn.BatchNorm1d(chan)
        )

    def forward(self, x):
        return F.relu(self.net(x) + x)


class MelEncoder(nn.Module):
    def __init__(self, channels, mel_channels=80):
        super().__init__()
        self.channels = channels
        self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
                                     ResBlock(channels//4),
                                     ResBlock(channels//4),
                                     nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
                                     nn.BatchNorm1d(channels//2),
                                     nn.ReLU(),
                                     ResBlock(channels//2),
                                     ResBlock(channels//2),
                                     ResBlock(channels//2),
                                     nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
                                     ResBlock(channels),
                                     ResBlock(channels),
                                     ResBlock(channels)
                                     )

    def forward(self, x):
        return self.encoder(x)


class GptAsr(nn.Module):
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1

    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000):
        super().__init__()
        self.max_mel_frames = max_mel_frames // 4  # Mel frames are reduced by a factor of 4 during encoding.
        self.max_symbols_per_phrase = max_symbols_per_phrase

        self.model_dim = model_dim
        self.max_mel_frames = self.max_mel_frames
        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
        self.mel_encoder = MelEncoder(model_dim)
        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
        self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2 + self.max_symbols_per_phrase + self.max_mel_frames, heads=heads,
                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.max_mel_frames)

        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)

    def get_logits(self, mel_inputs, text_targets):
        # Pad front and back. Pad at front is the "START" token.
        text_targets = F.pad(text_targets, (1, 0), value=self.NUMBER_SYMBOLS)
        text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
        text_emb = self.text_embedding(text_targets)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
        mel_emb = self.mel_encoder(mel_inputs)
        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
        mel_emb = mel_emb.permute(0,2,1).contiguous()
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
        emb = torch.cat([mel_emb, text_emb], dim=1)
        enc = self.gpt(emb)

        text_logits = self.final_norm(enc[:, self.max_mel_frames:])
        text_logits = self.text_head(text_logits)
        text_logits = text_logits.permute(0,2,1)
        return text_logits

    def forward(self, mel_inputs, text_targets):
        text_logits = self.get_logits(mel_inputs, text_targets)
        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
        return loss_text.mean(), text_logits

    def inference_beam_topk(self, mel, fn='inference_beam'):
        def topk_sampler(distribution, k):
            return torch.topk(distribution, k=k, dim=-1)
        return getattr(self, fn)(mel, topk_sampler)

    def inference_beam_sampled(self, mel, fn='inference_beam'):
        def multinomial_sampler(distribution, k):
            indices = torch.multinomial(distribution, num_samples=k, replacement=False)
            values = torch.gather(distribution, dim=1, index=indices)
            class container:
                def __init__(self, i, v):
                    self.indices = i
                    self.values = v
            return container(indices, values)
        return getattr(self, fn)(mel, multinomial_sampler)

    def inference_beam(self, mel_inputs, sampler_fn):
        beam_width = 16
        temperature = .8

        b, _, s = mel_inputs.shape
        assert b == 1  # Beam search only works on batches of one.
        mel_emb = self.mel_encoder(mel_inputs)
        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
        mel_emb = mel_emb.permute(0,2,1).contiguous()
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))

        text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
        probabilities = torch.ones((b,), device=mel_emb.device)
        while text_seq.shape[-1] < self.max_symbols_per_phrase:
            text_emb = self.text_embedding(text_seq)
            text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
            if text_emb.shape[0] != mel_emb.shape[0]:
                mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
            emb = torch.cat([mel_emb, text_emb], dim=1)
            enc = self.gpt(emb)
            text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
            text_logits = self.text_head(text_logits)
            topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
            probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
            probabilities, sort_indices = torch.sort(probabilities, descending=True)
            probabilities = probabilities[:beam_width]

            text_seq = text_seq.repeat_interleave(beam_width, dim=0)
            codes = topk.indices.flatten()
            text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1)
            text_seq = text_seq[sort_indices]
            text_seq = text_seq[:beam_width]

            # PAD doubles as a stop token. PAD=0.
            if torch.all(torch.any(text_seq == 0, dim=1)):
                break

        if text_seq.shape[1] >= self.max_mel_frames:
            print("Warning! Encountered frame limit before a pad token. Output is likely wrong.")

        return text_seq

    def inference_beam_opt(self, mel_inputs, sampler_fn):
        beam_width = 16
        temperature = .8

        b, _, s = mel_inputs.shape
        assert b == 1  # Beam search only works on batches of one.
        mel_emb = self.mel_encoder(mel_inputs)
        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
        mel_emb = mel_emb.permute(0,2,1).contiguous()
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))

        intermediates = []
        text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
        probabilities = torch.ones((b,), device=mel_emb.device)
        while text_seq.shape[-1] < self.max_symbols_per_phrase:
            text_emb = self.text_embedding(text_seq)
            text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
            if text_emb.shape[0] != mel_emb.shape[0]:
                mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
            emb = torch.cat([mel_emb, text_emb], dim=1)

            if len(intermediates) == 0:
                enc, intermediates = self.gpt(emb, return_intermediates=True)
                intermediates = [(i[0].repeat(beam_width, 1, 1),
                                  i[1].repeat(beam_width, 1, 1)) for i in intermediates]
            else:
                enc, intermediates = self.gpt.infer_last_two(emb, intermediates)

            text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
            text_logits = self.text_head(text_logits)
            topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
            probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
            probabilities, sort_indices = torch.sort(probabilities, descending=True)
            probabilities = probabilities[:beam_width]

            text_seq = text_seq.repeat_interleave(beam_width, dim=0)
            codes = topk.indices.flatten()
            text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1)
            text_seq = text_seq[sort_indices]
            text_seq = text_seq[:beam_width]

            # PAD doubles as a stop token. PAD=0.
            if torch.all(torch.any(text_seq == 0, dim=1)):
                break

        if text_seq.shape[1] >= self.max_mel_frames:
            print("Warning! Encountered frame limit before a pad token. Output is likely wrong.")

        return text_seq


@register_model
def register_gpt_asr(opt_net, opt):
    return GptAsr(**opt_get(opt_net, ['kwargs'], {}))


# Quick script that loads a model and halves the number of layers, then saves that model.
def distill():
    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
    rc = 0
    i = 0
    while i < len(gpt.gpt.layers.layers):
        if rc % 2 != 0:
            del gpt.gpt.layers.layers[i]
        else:
            i += 1
        rc += 1
    torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')


if __name__ == '__main__':
    gpt = GptAsr(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2).cuda()
    #l = gpt(torch.randn(2,80,800),
    #        torch.randint(high=len(symbols), size=(2,180)))

    with torch.no_grad():
        t = torch.randn(1,80,800).cuda()
        start = time()
        s = gpt.inference_beam_topk(t)
        print(time()-start)

        start = time()
        o = gpt.inference_beam_topk(t, fn='inference_beam_opt')
        print(time()-start)
@@ -1,102 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify

from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols, sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get


class ResBlock(nn.Module):
    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(chan, chan, kernel_size=5, padding=2),
            nn.BatchNorm1d(chan),
            nn.ReLU(),
            nn.Conv1d(chan, chan, kernel_size=5, padding=2),
            nn.BatchNorm1d(chan)
        )

    def forward(self, x):
        return F.relu(self.net(x) + x)


class MelEncoder(nn.Module):
    def __init__(self, channels, mel_channels=80):
        super().__init__()
        self.channels = channels
        self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
                                     ResBlock(channels//4),
                                     ResBlock(channels//4),
                                     nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
                                     nn.BatchNorm1d(channels//2),
                                     nn.ReLU(),
                                     ResBlock(channels//2),
                                     ResBlock(channels//2),
                                     ResBlock(channels//2),
                                     nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
                                     ResBlock(channels),
                                     ResBlock(channels),
                                     ResBlock(channels)
                                     )

    def forward(self, x):
        return self.encoder(x)


class GptSegmentor(nn.Module):
    MAX_MEL_FRAMES = 2000 // 4

    def __init__(self, layers=8, model_dim=512, heads=8):
        super().__init__()

        self.model_dim = model_dim
        self.max_mel_frames = self.MAX_MEL_FRAMES
        self.mel_encoder = MelEncoder(model_dim)
        self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=self.MAX_MEL_FRAMES, heads=heads,
                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)

        self.final_norm = nn.LayerNorm(model_dim)
        self.start_head = nn.Linear(model_dim, 1)
        self.stop_head = nn.Linear(model_dim, 1)

    def forward(self, mel_inputs, start_labels=None, end_labels=None):
        mel_emb = self.mel_encoder(mel_inputs)
        mel_emb = mel_emb.permute(0,2,1).contiguous()
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))

        enc = self.gpt(mel_emb)
        logits = self.final_norm(enc)
        stop_logits = self.stop_head(logits)
        start_logits = self.start_head(logits)

        if start_labels is not None:
            # Compute loss
            start_loss = F.binary_cross_entropy_with_logits(start_logits.squeeze(-1), start_labels.float())
            end_loss = F.binary_cross_entropy_with_logits(stop_logits.squeeze(-1), end_labels.float())
            return start_loss.mean(), end_loss.mean()
        else:
            return start_logits, stop_logits


@register_model
def register_gpt_segmentor(opt_net, opt):
    return GptSegmentor(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    gpt = GptSegmentor()
    l = gpt(torch.randn(3,80,94),
            torch.zeros(3,94))
    print(l.shape)

    #o = gpt.infer(torch.randint(high=24, size=(2,60)))
    #print(o.shape)
@@ -1,176 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify

from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get


class GptTts(nn.Module):
    MAX_SYMBOLS_PER_PHRASE = 200
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
    MEL_DICTIONARY_SIZE = 512+3
    MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
    MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2

    def __init__(self, layers=8, model_dim=512, heads=8):
        super().__init__()
        max_mel_frames = 900 * 1 // 4  # 900 is the max number of MEL frames. The VQVAE outputs 1/8 of the input mel as tokens.

        self.model_dim = model_dim
        self.max_mel_frames = max_mel_frames
        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
        self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
        self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=heads,
                               attn_dropout=.1, ff_dropout=.1)

        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
        self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

    def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
        mel_emb = self.mel_embedding(mel_targets)
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device))
        emb = torch.cat([text_emb, mel_emb], dim=1)
        enc = self.gpt(emb)

        # Compute logits for text and mel heads
        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
        mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
        text_logits = self.text_head(text_logits)
        mel_logits = self.mel_head(mel_logits)

        # Compute loss
        text_targets = text_inputs[:,1:]
        text_logits = text_logits.permute(0,2,1)[:,:,:-1]  # The last element of the logits is unneeded because the input to the transformer contains a <EOS> token for both text and mel.
        loss_text = F.cross_entropy(text_logits, text_targets, reduction='none')
        mel_targets = mel_targets[:,1:]
        mel_logits = mel_logits.permute(0,2,1)[:,:,:-1]
        loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')

        # Fix up mel_logits so it can go into a VAE decoder as well.
        mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1)
        mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1])
        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0)
        mel_codes = mel_codes[:,:-1]  # Strip off <EOS> token too (or padding). The important part is that the output sequence length is identical to the VAE input.
        extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3  # The VAE doesn't know about START/STOP/PAD
        mel_codes = mel_codes * extra_mask

        # This class also returns the mel_targets for validation purposes. Format those.
        mel_targets = mel_targets[:,:-1]
        mel_targets = mel_targets * (mel_targets < self.MEL_DICTIONARY_SIZE-3)
        return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets

    def inference(self, text_inputs):
        b, s = text_inputs.shape
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device))

        mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
        stop_encountered = torch.zeros((b,), device=text_emb.device)
        while not torch.all(stop_encountered) and len(mel_seq) < self.max_mel_frames:
            mel_emb = self.mel_embedding(mel_seq)
            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
            emb = torch.cat([text_emb, mel_emb], dim=1)
            enc = self.gpt(emb)
            mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
            mel_logits = self.mel_head(mel_logits)
            mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
            mel_seq = torch.cat([mel_seq, mel_codes[:, -1].unsqueeze(1)], dim=1)
            stop_encountered = torch.logical_or(stop_encountered, mel_seq[:,-1] == self.MEL_STOP_TOKEN)

        if len(mel_seq) >= self.max_mel_frames:
            print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
        mel_seq = mel_seq[:, 1:-1]  # Remove first and last tokens, which were artificially added for GPT
        mel_seq = mel_seq * (mel_seq < 512)  # The DVAE doesn't understand BOS/EOS/PAD tokens.

        return mel_seq

    def inference_beam_topk(self, text):
        def topk_sampler(distribution, k):
            return torch.topk(distribution, k=k, dim=-1)
        return self.inference_beam(text, topk_sampler)

    def inference_beam_sampled(self, text):
        def multinomial_sampler(distribution, k):
            indices = torch.multinomial(distribution, num_samples=k, replacement=False)
            values = torch.gather(distribution, dim=1, index=indices)
            class container:
                def __init__(self, i, v):
                    self.indices = i
                    self.values = v
            return container(indices, values)
        return self.inference_beam(text, multinomial_sampler)

    def inference_beam(self, text_inputs, sampler_fn):
        beam_width = 16
        temperature = .8

        b, s = text_inputs.shape
        assert b == 1  # Beam search only works on batches of one.
        text_emb = self.text_embedding(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device))
        mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
        probabilities = torch.ones((b,), device=text_emb.device)
        while len(mel_seq) < self.max_mel_frames:
            mel_emb = self.mel_embedding(mel_seq)
            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
            if text_emb.shape[0] != mel_emb.shape[0]:
                text_emb = text_emb.repeat(mel_emb.shape[0], 1, 1)
            emb = torch.cat([text_emb, mel_emb], dim=1)
            enc = self.gpt(emb)
            mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
            mel_logits = self.mel_head(mel_logits)
            topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width)
            probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
            probabilities, sort_indices = torch.sort(probabilities, descending=True)
            probabilities = probabilities[:beam_width]

            mel_seq = mel_seq.repeat_interleave(beam_width, dim=0)
            codes = topk.indices.flatten()
            mel_seq = torch.cat([mel_seq, codes.unsqueeze(1)], dim=1)
            mel_seq = mel_seq[sort_indices]
            mel_seq = mel_seq[:beam_width]

            if torch.all(torch.any(mel_seq == self.MEL_STOP_TOKEN, dim=1)):
                break

        if mel_seq.shape[1] >= self.max_mel_frames:
            print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

        # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
        mel_seq = mel_seq[0, 1:-1].unsqueeze(0)  # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT
        mel_seq = mel_seq * (mel_seq < 512)  # The DVAE doesn't understand BOS/EOS/PAD tokens.

        return mel_seq


@register_model
def register_gpt_tts(opt_net, opt):
    return GptTts(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    gpt = GptTts()
    l1, l2, i = gpt(torch.randint(high=24, size=(2,60)),
                    torch.tensor([55,58]),
                    torch.randint(high=512, size=(2,310)),
                    torch.tensor([300,305]))
    print(i.shape)

    #o = gpt.infer(torch.randint(high=24, size=(2,60)))
    #print(o.shape)
@@ -1,236 +0,0 @@
from inspect import isfunction

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

# helpers
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
from utils.util import checkpoint


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cast_tuple(val, depth = 1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth


class DivideMax(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim = self.dim, keepdim = True)
        return x / maxes


# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0., mult = 4.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, only_last_two_elements=False):
        if only_last_two_elements:
            h = x[:, -2:]
            h = self.net(h)
            return torch.cat([x[:, :-2], h], dim=1)
        else:
            return self.net(x)


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True)
    return (t * alpha).softmax(dim = dim)


# classes
class Attention(nn.Module):
    def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5

        self.stable = stable
        self.non_causal_sequence_partition = non_causal_sequence_partition

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None, only_last_two_elements=False):
        b, n, _, h, device = *x.shape, self.heads, x.device
        softmax = torch.softmax if not self.stable else stable_softmax

        # TODO: Q and V do not need to be recomputed for existing elements if intermediate_latents is specified. V would need to be cached though.
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        q = q * self.scale

        if only_last_two_elements:
            q = q[:, :, -2:]
            assert not exists(mask)  # Don't know how to resolve this (currently)

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        mask_value = max_neg_value(dots)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)
            del mask

        i, j = dots.shape[-2:]
        mask = torch.ones(i, j, device = device).triu_(j - i + 1)
        if self.non_causal_sequence_partition > 0:
            non_causal_mask = torch.ones((i, j), device=device)
            non_causal_mask[:, :self.non_causal_sequence_partition] = 0
            mask = mask * non_causal_mask

        dots.masked_fill_(mask.bool(), mask_value)

        attn = softmax(dots, dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        reversible = False,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        sparse_attn = False,
        stable = False,
        non_causal_sequence_partition=0,
    ):
        super().__init__()
        layers = nn.ModuleList([])
        sparse_layer = cast_tuple(sparse_attn, depth)

        for ind, sparse_attn in zip(range(depth), sparse_layer):
            attn = Attention(dim, stable=stable, non_causal_sequence_partition = non_causal_sequence_partition, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)

            layers.append(nn.ModuleList([
                LayerScale(dim, ind + 1, PreNorm(dim, attn)),
                LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
            ]))

        # TODO: Remove this nonsense. I don't want to support reversible sequences and this is just a mess.
        execute_type = ReversibleSequence if reversible else SequentialSequence
        route_attn = ((True, False),) * depth
        attn_route_map = {'mask': route_attn}

        self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
        self.depth = depth

    def forward(self, x, return_intermediates=False):
        intermediates = []
        for attn, ff in self.layers.layers:
            x_ff = x + checkpoint(attn, x)
            x = x_ff + ff(x_ff)
            if return_intermediates:
                intermediates.append((x_ff, x))
        if return_intermediates:
            return x, intermediates
        else:
            return x

    def infer_last_two(self, x, prev_intermediates):
        """
        Performs a forward pass only on the last two elements in the given sequence (allowing them to attend to all
        other elements). This is useful for faster autoregressive decoding.

        The last two elements are important because in inference, the last element is the prediction candidate and the
        second-to-last element is a newly selected element from the autoregressive searching process.
        """
        assert(len(prev_intermediates) == self.depth)
        new_intermediates = []
        for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
            x_ff = attn(x, only_last_two_elements=True)
            # Note that (x) is now only the last two elements in the set. Conjoin it with the int_ff latent to compute the norm.
            x_ff = x + torch.cat([int_ff[:,:-1], x_ff], dim=1)
            x = x_ff + ff(x_ff, only_last_two_elements=True)
            new_intermediates.append((x_ff, x))
        return x, new_intermediates
codes/models/gpt_voice/voice_clip.py (new file, 0 lines)
codes/models/lucidrains/dalle/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# This directory contains some useful code from https://github.com/lucidrains/DALLE-pytorch/tree/main/dalle_pytorch

codes/models/lucidrains/dalle/attention.py (new file, 384 lines)

@@ -0,0 +1,384 @@
from inspect import isfunction
from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

from rotary_embedding_torch import apply_rotary_emb

# helpers

def exists(val):
    return val is not None

def uniq(arr):
    return {el: True for el in arr}.keys()

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def max_neg_value(t):
    return -torch.finfo(t.dtype).max

def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True).detach()
    return (t * alpha).softmax(dim = dim)

def apply_pos_emb(pos_emb, qkv):
    n = qkv[0].shape[-2]
    pos_emb = pos_emb[..., :n, :]
    return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))

# classes

class Attention(nn.Module):
    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5

        self.stable = stable
        self.causal = causal

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, device = *x.shape, self.heads, x.device
        softmax = torch.softmax if not self.stable else stable_softmax

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        q = q * self.scale

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        mask_value = max_neg_value(dots)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.causal:
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        attn = softmax(dots, dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

# sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation

class SparseConvCausalAttention(nn.Module):
    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        super().__init__()
        assert kernel_size % 2 == 1, 'kernel size must be odd'

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size
        self.kernel_size = kernel_size
        self.dilation = dilation

        self.stable = stable

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
        softmax = torch.softmax if not self.stable else stable_softmax

        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())

        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]

        # derive query / keys / values

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        q *= self.scale

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        effective_kernel_size = (kernel_size - 1) * dilation + 1
        padding = effective_kernel_size // 2

        k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img))
        k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))

        # let image attend to all of text

        dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
        dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

        # calculate causal attention for local convolution

        i, j = dots_image.shape[-2:]
        img_seq = torch.arange(img_seq_len, device = device)
        k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
        k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len)  # padding set to be max, so it is never attended to
        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
        k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')

        # mask image attention

        q_img_indices = rearrange(img_seq, 'i -> () i ()')
        causal_mask = q_img_indices < k_img_indices

        # concat text mask with image causal mask

        causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h)
        mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
        mask = torch.cat((~mask, causal_mask), dim = -1)

        # image can attend to all of text

        dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
        dots.masked_fill_(mask, mask_value)

        attn = softmax(dots, dim = -1)

        # aggregate

        attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img)
        out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]

# sparse axial causal attention

class SparseAxialCausalAttention(nn.Module):
    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        super().__init__()
        assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
        self.axis = axis

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size

        self.stable = stable

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
        softmax = torch.softmax if not self.stable else stable_softmax

        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())

        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]

        # derive queries / keys / values

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        q *= self.scale

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
        merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'

        # split out axis

        q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))

        # similarity

        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

        # mask so image has full attention to text, but causal along axis

        bh, x, i, j = dots.shape
        causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
        causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)

        mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
        mask = torch.cat((~mask, causal_mask), dim = -1)

        dots.masked_fill_(mask, mask_value)

        # attention.

        attn = softmax(dots, dim = -1)

        # aggregate

        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # merge back axis

        out_image = rearrange(out_image, merge_axis_einops, x = img_size)

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]

# microsoft sparse attention CUDA kernel

class SparseAttention(Attention):
    def __init__(
        self,
        *args,
        block_size = 16,
        text_seq_len = 256,
        num_random_blocks = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
        self.block_size = block_size

        num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)
        global_block_indices = list(range(ceil(text_seq_len / block_size)))

        self.attn_fn = SparseSelfAttention(
            sparsity_config = VariableSparsityConfig(
                num_heads = self.heads,
                block = self.block_size,
                num_random_blocks = num_random_blocks,
                global_block_indices = global_block_indices,
                attention = 'unidirectional' if self.causal else 'bidirectional'
            ),
            max_seq_length = self.seq_len,
            attn_mask_mode = 'add'
        )

    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, device = *x.shape, self.heads, x.device
        remainder = n % self.block_size
        mask = default(mask, lambda: torch.ones(b, n, device = device).bool())

        if remainder > 0:
            padding = self.block_size - remainder
            x = F.pad(x, (0, 0, 0, padding), value = 0)
            mask = F.pad(mask, (0, padding), value = False)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        key_pad_mask = None
        if exists(mask):
            key_pad_mask = ~mask

        attn_mask = None
        if self.causal:
            i, j = q.shape[-2], k.shape[-2]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            attn_mask = torch.zeros(i, j, device = device).to(q)
            mask_value = max_neg_value(q) / 2
            attn_mask.masked_fill_(mask, mask_value)

        out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out[:, :n]
@@ -1,12 +1,10 @@
import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# for routing arguments into the functions of the reversible layer
from utils.util import checkpoint


def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]
@@ -126,25 +124,20 @@ class _ReversibleFunction(Function):
        return dy, None, None

class SequentialSequence(nn.Module):
    def __init__(self, layers, args_route = {}, layer_dropout = 0., checkpoint=False):
    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout
        self.checkpoint = checkpoint

    def forward(self, x, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        for (f, g), (f_args, g_args) in layers_and_args:
            if self.checkpoint:
                x = x + f(x, **f_args)
                x = x + g(x, **g_args)
            else:
                x = x + checkpoint(f, x, **f_args)
                x = x + checkpoint(g, x, **g_args)
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x

class ReversibleSequence(nn.Module):
codes/models/lucidrains/dalle/transformer.py (new file, 231 lines)

@@ -0,0 +1,231 @@
from functools import partial
from itertools import islice, cycle

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

from models.lucidrains.dalle.reversible import ReversibleSequence, SequentialSequence
from models.lucidrains.dalle.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention

from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, depth = 1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth

# classes

class DivideMax(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim = self.dim, keepdim = True).detach()
        return x / maxes

# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# layer norm

class PreNorm(nn.Module):
    def __init__(self, dim, fn, sandwich = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)

# feed forward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0., mult = 4.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# token shift classes

class PreShiftToken(nn.Module):
    def __init__(self, fn, image_size, seq_len):
        super().__init__()
        self.fn = fn
        self.image_size = image_size
        self.seq_len = seq_len

    def forward(self, x, **kwargs):
        n = x.shape[1]
        seq_len, image_size = self.seq_len, self.image_size
        img_seq_len = image_size ** 2
        text_len = seq_len - img_seq_len + 1
        padding = seq_len - n + 1

        # get text and image tokens

        x_text, x_img = x[:, :text_len], x[:, text_len:]
        x_img = F.pad(x_img, (0, 0, 0, padding))
        x_img = rearrange(x_img, 'b (h w) d -> b h w d', h = image_size)

        # shift 1 from the left for text tokens

        x_text_shift, x_text_pass = x_text.chunk(2, dim = -1)
        x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1))
        x_text = torch.cat((x_text_shift, x_text_pass), dim = -1)

        # shift from top, left for image tokens

        x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim = -1)
        x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1))
        x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1))
        x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim = -1)

        # merge text and image sequence back together

        x_img = rearrange(x_img, 'b h w d -> b (h w) d')
        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
        return self.fn(x, **kwargs)

# main transformer class

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        reversible = False,
        causal = True,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_types = None,
        image_fmap_size = None,
        sparse_attn = False,
        stable = False,
        sandwich_norm = False,
        shift_tokens = False,
        rotary_emb = True
    ):
        super().__init__()
        layers = nn.ModuleList([])
        sparse_layer = cast_tuple(sparse_attn, depth)

        attn_types = default(attn_types, ('full',))
        attn_types = cast_tuple(attn_types)
        attn_type_layer = islice(cycle(attn_types), depth)

        for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
            if attn_type == 'full':
                attn_class = partial(Attention, stable = stable)
            elif attn_type == 'sparse':
                attn_class = SparseAttention
            elif attn_type == 'axial_row':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
            elif attn_type == 'axial_col':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
            elif attn_type == 'conv_like':
                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
            elif attn_type == 'mlp':
                attn_class = partial(gMLPBlock, seq_len = seq_len)
            else:
                raise ValueError(f'attention type "{attn_type}" is not valid')

            if attn_type != 'mlp':
                attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
            else:
                attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)

            ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)

            if shift_tokens:
                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))

            layers.append(nn.ModuleList([
                LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
                LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
            ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence
        route_attn = ((True, False),) * depth
        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}

        self.layers = execute_type(layers, args_route = attn_route_map)

        # generate positional embeddings for rotary

        pos_emb = None
        if rotary_emb:
            assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on'

            rot_dim = dim_head // 3
            img_seq_len = (image_fmap_size ** 2)
            text_len = seq_len - img_seq_len + 1

            text_pos_emb = RotaryEmbedding(dim = rot_dim)
            img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel')

            text_freqs = text_pos_emb(torch.arange(text_len))
            img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192))  # image is given a position far away from text
            text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim = 0)

            img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size))
            img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), rearrange(img_freqs_axial, 'j d -> () j d')), dim = -1)
            img_freqs = rearrange(img_freqs, 'h w d -> (h w) d')

            text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.))  # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1]
            text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim = -1)
            img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0)

            pos_emb = torch.cat((text_freqs, img_freqs), dim = -1)
            pos_emb = rearrange(pos_emb, 'n d -> () n d')

        self.register_buffer('pos_emb', pos_emb)

    def forward(self, x, **kwargs):
        return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
@@ -11,18 +11,15 @@ munch
tqdm
scp
tensorboard
linear_attention_transformer
orjson
einops
lambda-networks
vector-quantize-pytorch

# For image generation stuff
opencv-python
kornia
pytorch_ssim
gsa-pytorch
vector_quantize_pytorch
pytorch_fid==0.1.1

# For audio generation stuff

@@ -31,4 +28,15 @@ librosa==0.6.0
Unidecode==1.0.22
tgt == 1.4.4
pyworld == 0.2.10
audio2numpy
audio2numpy

# For text stuff
transformers
tokenizers

# lucidrains stuff
vector_quantize_pytorch
linear_attention_transformer
rotary-embedding-torch
axial_positional_embedding
g-mlp-pytorch