From 8fb147e8ab6a6cea25a8a28f7c998174c1da6e08 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 4 Feb 2022 11:00:15 -0700
Subject: [PATCH] add an autoregressive ctc code generator

---
 codes/data/audio/fast_paired_dataset.py       |  2 +-
 codes/models/gpt_voice/ctc_code_generator2.py | 94 +++++++++++++++++++
 codes/train.py                                |  2 +-
 3 files changed, 96 insertions(+), 2 deletions(-)
 create mode 100644 codes/models/gpt_voice/ctc_code_generator2.py

diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index a3a09059..ed405608 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -118,7 +118,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         codes, repeats, pads = [], [], [0]
         for val, group in grouped:
             if val == 0:
-                pads[-1] = len(list(group))
+                pads[-1] = len(list(group)) # This is a very important distinction! It means the padding belongs to the character following it.
             else:
                 codes.append(val)
                 repeats.append(len(list(group)))
diff --git a/codes/models/gpt_voice/ctc_code_generator2.py b/codes/models/gpt_voice/ctc_code_generator2.py
new file mode 100644
index 00000000..1d6db586
--- /dev/null
+++ b/codes/models/gpt_voice/ctc_code_generator2.py
@@ -0,0 +1,94 @@
+import functools
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import T5Config, T5Model
+from x_transformers import Encoder, XTransformer
+
+from models.gpt_voice.transformer_builders import null_position_embeddings
+from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
+from models.gpt_voice.unified_voice2 import ConditioningEncoder
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class CtcCodeGenerator(nn.Module):
+    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30, checkpointing=True):
+        super().__init__()
+        self.max_pad = max_pad
+        self.max_repeat = max_repeat
+        self.start_token = (self.max_repeat+1)*(self.max_pad+1)+1
+        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads)
+        self.embedding = nn.Embedding(ctc_codes, model_dim)
+        self.dec_embedding = nn.Embedding(self.start_token+1, model_dim)
+        self.config = T5Config(
+            vocab_size=1, # T5 embedding will be removed and replaced with custom embedding.
+            d_model=model_dim,
+            d_kv=model_dim//num_heads,
+            d_ff=model_dim*4,
+            num_layers=layers,
+            num_heads=num_heads,
+            dropout_rate=dropout,
+            feed_forward_proj='gated-gelu',
+            use_cache=not checkpointing,
+            gradient_checkpointing=checkpointing
+        )
+        self.transformer = T5Model(self.config)
+        del self.transformer.encoder.embed_tokens
+        del self.transformer.decoder.embed_tokens
+        self.transformer.encoder.embed_tokens = functools.partial(null_position_embeddings, dim=model_dim)
+        self.transformer.decoder.embed_tokens = functools.partial(null_position_embeddings, dim=model_dim)
+        self.output_layer = nn.Linear(model_dim, self.start_token+1)
+
+
+    def forward(self, conditioning_input, codes, pads, repeats, unpadded_lengths):
+        max_len = unpadded_lengths.max()
+        codes = codes[:, :max_len]
+        pads = pads[:, :max_len]
+        repeats = repeats[:, :max_len]
+
+        if pads.max() > self.max_pad:
+            print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
+            pads = torch.clip(pads, 0, self.max_pad)
+        if repeats.max() > self.max_repeat:
+            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
+            repeats = torch.clip(repeats, 0, self.max_repeat)
+        assert codes.max() < 36, codes.max()
+
+        conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
+        conds = []
+        for j in range(conditioning_input.shape[1]):
+            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
+        conds = torch.stack(conds, dim=1)
+        h = torch.cat([conds, self.embedding(codes)], dim=1)
+
+        labels = pads + repeats * (self.max_pad + 1) + 1 # Pack each (pad, repeat) pair into one token id; the max_pad+1 stride keeps pairs from colliding and matches start_token.
+        for i in range(unpadded_lengths.shape[0]):
+            labels[i, unpadded_lengths[i]:] = 0
+        labels_in = F.pad(labels, (1,0), value=self.start_token)
+        h_dec = self.dec_embedding(labels_in)
+
+        h = self.transformer(inputs_embeds=h, decoder_inputs_embeds=h_dec).last_hidden_state
+        logits = self.output_layer(h)
+        logits = logits.permute(0,2,1)[:,:,:-1] # Strip off the last token. There is no "stop" token here, so this is just an irrelevant prediction on some future that doesn't actually exist.
+        loss = F.cross_entropy(logits, labels, reduction='none')
+
+        # Ignore the first predictions of the sequences. This corresponds to the padding for the first CTC character, which is pretty much random and cannot be predicted.
+        #loss = loss[1:].mean()
+        return loss
+
+
+@register_model
+def register_ctc_code_generator2(opt_net, opt):
+    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    model = CtcCodeGenerator()
+    conds = torch.randn(4,2,80,600)
+    inps = torch.randint(0,36, (4, 300))
+    pads = torch.randint(0,100, (4,300))
+    repeats = torch.randint(0,20, (4,300))
+    loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
+    print(loss.shape)
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index 63e2c4a0..be9cabac 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -299,7 +299,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts6_upsample.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
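
For reference, the fast_paired_dataset.py hunk above only shows the inner grouping loop. A minimal, self-contained sketch of what that run-length grouping produces is below; the function name and the pads.append(0) bookkeeping outside the visible hunk are assumptions, not part of the patch:

from itertools import groupby

def group_ctc_codes(raw_codes):
    # Collapse a raw CTC sequence into (codes, repeats, pads). pads[i] counts the
    # blank (0) tokens immediately before codes[i], i.e. the padding "belongs to"
    # the character that follows it.
    codes, repeats, pads = [], [], [0]
    for val, group in groupby(raw_codes):
        run = len(list(group))
        if val == 0:
            pads[-1] = run
        else:
            codes.append(val)
            repeats.append(run)
            pads.append(0)  # assumed: open a fresh pad slot for the next blank run
    return codes, repeats, pads

# [0, 0, 5, 5, 5, 0, 0, 0, 7] -> ([5, 7], [3, 1], [2, 3, 0])
print(group_ctc_codes([0, 0, 5, 5, 5, 0, 0, 0, 7]))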
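ctc_code_generator2.py builds its decoder vocabulary by packing each (pad, repeat) pair into a single token id: 0 is reserved for positions past the unpadded length, ids 1..(max_repeat+1)*(max_pad+1) hold the pairs, and one extra id is the start token. A small illustration of that arithmetic with the default sizes; the pack/unpack helper names are mine, not part of the patch:

MAX_PAD, MAX_REPEAT = 120, 30
START_TOKEN = (MAX_REPEAT + 1) * (MAX_PAD + 1) + 1  # 3752 with these defaults

def pack(pad, repeat):
    # Mirrors `labels = pads + repeats * (max_pad + 1) + 1`; ids start at 1 because
    # 0 marks padded-out positions in the label tensor.
    return pad + repeat * (MAX_PAD + 1) + 1

def unpack(token):
    token -= 1
    return token % (MAX_PAD + 1), token // (MAX_PAD + 1)  # -> (pad, repeat)

assert unpack(pack(17, 4)) == (17, 4)
assert pack(MAX_PAD, MAX_REPEAT) == START_TOKEN - 1  # every pair fits below the start token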
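The patch only defines the teacher-forced training loss; it does not include a sampling routine. A hedged sketch of how greedy decoding could be run against this model, reusing the same encoder-side construction as forward() (the greedy_decode helper and its steps argument are assumptions about intended usage, not code from the patch):

import torch

@torch.no_grad()
def greedy_decode(model, conditioning_input, codes, steps):
    # Build encoder-side embeddings the same way CtcCodeGenerator.forward() does.
    cond = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
    conds = torch.stack([model.conditioning_encoder(cond[:, j]) for j in range(cond.shape[1])], dim=1)
    enc = torch.cat([conds, model.embedding(codes)], dim=1)

    # Start from the start token and greedily append one packed (pad, repeat) id per step.
    tokens = torch.full((codes.shape[0], 1), model.start_token, dtype=torch.long, device=codes.device)
    for _ in range(steps):
        dec = model.dec_embedding(tokens)
        h = model.transformer(inputs_embeds=enc, decoder_inputs_embeds=dec).last_hidden_state
        next_token = model.output_layer(h[:, -1]).argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens[:, 1:]  # drop the start token; each id packs one (pad, repeat) pair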