diff --git a/codes/models/gpt_voice/gpt_asr.py b/codes/models/gpt_voice/gpt_asr.py
new file mode 100644
index 00000000..da606ee6
--- /dev/null
+++ b/codes/models/gpt_voice/gpt_asr.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from munch import munchify
+
+from models.gpt_voice.lucidrains_gpt import Transformer
+from models.tacotron2.taco_utils import get_mask_from_lengths
+from models.tacotron2.text import symbols
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class ResBlock(nn.Module):
+    def __init__(self, chan):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
+            nn.BatchNorm1d(chan),
+            nn.ReLU(),
+            nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
+            nn.BatchNorm1d(chan)
+        )
+
+    def forward(self, x):
+        return F.relu(self.net(x) + x)
+
+
+class MelEncoder(nn.Module):
+    def __init__(self, channels, mel_channels=80):
+        super().__init__()
+        self.channels = channels
+        self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
+                                     ResBlock(channels//4),
+                                     ResBlock(channels//4),
+                                     nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
+                                     nn.BatchNorm1d(channels//2),
+                                     nn.ReLU(),
+                                     ResBlock(channels//2),
+                                     ResBlock(channels//2),
+                                     ResBlock(channels//2),
+                                     nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
+                                     ResBlock(channels),
+                                     ResBlock(channels),
+                                     ResBlock(channels)
+                                     )
+
+    def forward(self, x):
+        return self.encoder(x)
+
+
+class GptAsr(nn.Module):
+    MAX_SYMBOLS_PER_PHRASE = 200
+    MAX_MEL_FRAMES = 1000 // 4
+    NUMBER_SYMBOLS = len(symbols)
+    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS
+
+    def __init__(self, layers=8, model_dim=512, heads=8):
+        super().__init__()
+
+        self.model_dim = model_dim
+        self.max_mel_frames = self.MAX_MEL_FRAMES
+        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
+        self.mel_encoder = MelEncoder(model_dim)
+        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
+        self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
+        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+self.MAX_MEL_FRAMES, heads=heads,
+                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
+
+        self.final_norm = nn.LayerNorm(model_dim)
+        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
+
+    def forward(self, mel_inputs, text_targets):
+        text_targets = F.pad(text_targets, (0, self.MAX_SYMBOLS_PER_PHRASE-text_targets.shape[1]))
+        text_emb = self.text_embedding(text_targets)
+        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
+        mel_emb = self.mel_encoder(mel_inputs)
+        mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
+        mel_emb = mel_emb.permute(0,2,1).contiguous()
+        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
+        emb = torch.cat([mel_emb, text_emb], dim=1)
+
+        enc = self.gpt(emb)
+
+        # Compute loss
+        text_logits = self.final_norm(enc[:, self.MAX_MEL_FRAMES:])
+        text_logits = self.text_head(text_logits)
+        text_logits = text_logits.permute(0,2,1)
+        loss_text = F.cross_entropy(text_logits[:,:,1:], text_targets[:,:-1].long())
+
+        return loss_text.mean()
+
+
+@register_model
+def register_gpt_asr(opt_net, opt):
+    return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    gpt = GptAsr()
+    l = gpt(torch.randn(2,80,800),
+            torch.randint(high=len(symbols), size=(2,180)))
+    print(l.shape)
+
+    #o = gpt.infer(torch.randint(high=24, size=(2,60)))
+    #print(o.shape)
+
diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py
index f71a4c88..ed766384 100644
--- a/codes/models/gpt_voice/lucidrains_gpt.py
+++ b/codes/models/gpt_voice/lucidrains_gpt.py
@@ -108,7 +108,7 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
 # classes
 
 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
+    def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -116,7 +116,7 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
         self.stable = stable
 
-        self.causal = causal
+        self.non_causal_sequence_partition = non_causal_sequence_partition
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -141,10 +141,14 @@ class Attention(nn.Module):
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        if self.causal:
-            i, j = dots.shape[-2:]
-            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-            dots.masked_fill_(mask, mask_value)
+        i, j = dots.shape[-2:]
+        mask = torch.ones(i, j, device = device).triu_(j - i + 1)
+        if self.non_causal_sequence_partition > 0:
+            non_causal_mask = torch.ones((i, j), device=device)
+            non_causal_mask[:, :self.non_causal_sequence_partition] = 0
+            mask = mask * non_causal_mask
+
+        dots.masked_fill_(mask.bool(), mask_value)
 
         attn = softmax(dots, dim=-1)
 
@@ -162,21 +166,21 @@ class Transformer(nn.Module):
         depth,
         seq_len,
         reversible = False,
-        causal = True,
         heads = 8,
         dim_head = 64,
         ff_mult = 4,
         attn_dropout = 0.,
         ff_dropout = 0.,
         sparse_attn = False,
-        stable = False
+        stable = False,
+        non_causal_sequence_partition=0,
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)
 
         for ind, sparse_attn in zip(range(depth), sparse_layer):
-            attn = Attention(dim, stable=stable, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+            attn = Attention(dim, stable=stable, non_causal_sequence_partition = non_causal_sequence_partition, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
 
             layers.append(nn.ModuleList([
                 LayerScale(dim, ind + 1, PreNorm(dim, attn)),
diff --git a/codes/train.py b/codes/train.py
index b5208f82..d2aa5532 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -282,7 +282,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_mozcv.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mozcv.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index 0a4ffebc..3eeb8749 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -60,7 +60,7 @@ class GeneratorInjector(Injector):
         results = method(*params)
         new_state = {}
         if isinstance(self.output, list):
-            # Only dereference tuples or lists, not tensors.
+            # Only dereference tuples or lists, not tensors. IF YOU REACH THIS ERROR, REMOVE THE BRACES AROUND YOUR OUTPUTS IN THE YAML CONFIG
             assert isinstance(results, list) or isinstance(results, tuple)
             for i, k in enumerate(self.output):
                 new_state[k] = results[i]
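
Note: the mask logic that replaces the old `causal` flag in lucidrains_gpt.py can be checked in isolation. The snippet below is not part of the commit; it is a minimal sketch (the helper name `build_attention_mask` is made up) that repeats the same arithmetic so the intended behaviour is visible on a toy sequence: every query may attend to the first `non_causal_sequence_partition` key positions (the mel frames), while attention over the remaining positions stays causal.

import torch

def build_attention_mask(i, j, non_causal_sequence_partition=0, device='cpu'):
    # Same construction as the new Attention.forward code above:
    # start from a strict upper-triangular (causal) blocking mask of shape (i, j)...
    mask = torch.ones(i, j, device=device).triu_(j - i + 1)
    if non_causal_sequence_partition > 0:
        # ...then clear the first `non_causal_sequence_partition` key columns, so
        # every query position may attend to that prefix (the mel frames).
        non_causal_mask = torch.ones((i, j), device=device)
        non_causal_mask[:, :non_causal_sequence_partition] = 0
        mask = mask * non_causal_mask
    return mask.bool()  # True marks positions that get filled with mask_value

if __name__ == '__main__':
    # Toy layout: 3 "mel" positions followed by 4 "text" positions.
    print(build_attention_mask(7, 7, non_causal_sequence_partition=3).int())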