From 4249681c4b18e92ba20f8d69cb4b5196861c4ead Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 3 Feb 2022 19:58:54 -0700
Subject: [PATCH] Mods to support a autoregressive CTC code generator

---
 codes/data/audio/fast_paired_dataset.py      | 54 ++++++++++--
 codes/models/gpt_voice/ctc_code_generator.py | 87 ++++++++++++++++++++
 codes/train.py                               |  2 +-
 3 files changed, 135 insertions(+), 8 deletions(-)
 create mode 100644 codes/models/gpt_voice/ctc_code_generator.py

diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index 638ec81a..a3a09059 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -1,8 +1,8 @@
 import hashlib
 import os
-import os
 import random
 import sys
+from itertools import groupby
 
 import torch
 import torch.nn.functional as F
@@ -12,8 +12,6 @@ from tqdm import tqdm
 
 from data.audio.paired_voice_audio_dataset import CharacterTokenizer
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
-from models.tacotron2.taco_utils import load_filepaths_and_text
-from models.tacotron2.text import text_to_sequence, sequence_to_text
 from utils.util import opt_get
 
 
@@ -53,6 +51,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
         self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
         self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
+        self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
         self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
         self.text_cleaners = hparams.text_cleaners
         self.sample_rate = hparams.sample_rate
@@ -114,6 +113,39 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
                 print(f"error parsing random offset: {sys.exc_info()}")
         return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again.
 
+    def get_ctc_metadata(self, codes):
+        grouped = groupby(codes.tolist())
+        codes, repeats, pads = [], [], [0]
+        for val, group in grouped:
+            if val == 0:
+                pads[-1] = len(list(group))
+            else:
+                codes.append(val)
+                repeats.append(len(list(group)))
+                pads.append(0)
+
+        codes = torch.tensor(codes)
+        # These clip values are sane maximum values which I did not see in the datasets I have access to.
+        repeats = torch.clip(torch.tensor(repeats), max=30)
+        pads = torch.clip(torch.tensor(pads[:-1]), max=120)
+
+        # Pad or clip the codes to get them to exactly self.max_text_len
+        orig_lens = codes.shape[0]
+        if codes.shape[0] < self.max_text_len:
+            gap = self.max_text_len - codes.shape[0]
+            codes = F.pad(codes, (0, gap))
+            repeats = F.pad(repeats, (0, gap))
+            pads = F.pad(pads, (0, gap))
+        elif codes.shape[0] > self.max_text_len:
+            codes = codes[:self.max_text_len]
+            repeats = codes[:self.max_text_len]
+            pads = pads[:self.max_text_len]
+        return {
+            'ctc_raw_codes': codes,
+            'ctc_pads': pads,
+            'ctc_repeats': repeats,
+            'ctc_raw_lengths': orig_lens,
+        }
 
     def __getitem__(self, index):
         self.skipped_items += 1
@@ -130,7 +162,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             if self.debug_failures:
                 print(f"error loading {apt[0]} {sys.exc_info()}")
             return self[(index+1) % len(self)]
-        aligned_codes = apt[2]
+        raw_codes = apt[2]
+        aligned_codes = raw_codes
 
         actually_skipped_items = self.skipped_items
         self.skipped_items = 0
@@ -166,6 +199,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         if self.load_conditioning:
             res['conditioning'] = cond
             res['conditioning_contains_self'] = cond_is_self
+        if self.produce_ctc_metadata:
+            res.update(self.get_ctc_metadata(raw_codes))
         return res
 
     def __len__(self):
@@ -223,6 +258,7 @@ if __name__ == '__main__':
         'conditioning_length': 44000,
         'use_bpe_tokenizer': False,
         'load_aligned_codes': True,
+        'produce_ctc_metadata': True,
     }
     from data import create_dataset, create_dataloader
 
@@ -236,10 +272,14 @@ if __name__ == '__main__':
     dl = create_dataloader(ds, params, collate_fn=c)
     i = 0
     m = None
+    max_pads, max_repeats = 0, 0
     for i, b in tqdm(enumerate(dl)):
         for ib in range(batch_sz):
+            max_pads = max(max_pads, b['ctc_pads'].max())
+            max_repeats = max(max_repeats, b['ctc_repeats'].max())
             print(f'{i} {ib} {b["real_text"][ib]}')
-            save(b, i, ib, 'wav')
-        if i > 5:
-            break
+            #save(b, i, ib, 'wav')
+        #if i > 5:
+        #    break
+    print(max_pads, max_repeats)
 
diff --git a/codes/models/gpt_voice/ctc_code_generator.py b/codes/models/gpt_voice/ctc_code_generator.py
new file mode 100644
index 00000000..5dfce669
--- /dev/null
+++ b/codes/models/gpt_voice/ctc_code_generator.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from x_transformers import Encoder, XTransformer
+
+from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class CheckpointedXTransformerEncoder(nn.Module):
+    """
+    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
+    to channels-last that XTransformer expects.
+    """
+    def __init__(self, **xtransformer_kwargs):
+        super().__init__()
+        self.transformer = XTransformer(**xtransformer_kwargs)
+
+        for xform in [self.transformer.encoder, self.transformer.decoder.net]:
+            for i in range(len(xform.attn_layers.layers)):
+                n, b, r = xform.attn_layers.layers[i]
+                xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
+
+    def forward(self, *args, **kwargs):
+        return self.transformer(*args, **kwargs)
+
+
+class CtcCodeGenerator(nn.Module):
+    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30):
+        super().__init__()
+        self.max_pad = max_pad
+        self.max_repeat = max_repeat
+        self.transformer = XTransformer(
+            dim=model_dim,
+            enc_depth=layers,
+            dec_depth=layers,
+            enc_heads=num_heads,
+            dec_heads=num_heads,
+            enc_num_tokens=ctc_codes,
+            dec_num_tokens=(max_pad+1)*(max_repeat+1),
+            enc_max_seq_len=-1,
+            dec_max_seq_len=-1,
+
+            enc_ff_dropout=dropout,
+            enc_attn_dropout=dropout,
+            enc_use_rmsnorm=True,
+            enc_ff_glu=True,
+            enc_rotary_pos_emb=True,
+            dec_ff_dropout=dropout,
+            dec_attn_dropout=dropout,
+            dec_use_rmsnorm=True,
+            dec_ff_glu=True,
+            dec_rotary_pos_emb=True)
+
+    def forward(self, codes, pads, repeats, unpadded_lengths=None):
+        if unpadded_lengths is not None:
+            max_len = unpadded_lengths.max()
+            codes = codes[:, :max_len]
+            pads = pads[:, :max_len]
+            repeats = repeats[:, :max_len]
+
+        if pads.max() > self.max_pad:
+            print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
+            pads = torch.clip(pads, 0, self.max_pad)
+        if repeats.max() > self.max_repeat:
+            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
+            repeats = torch.clip(repeats, 0, self.max_repeat)
+        assert codes.max() < 36, codes.max()
+
+        labels = pads + repeats * self.max_pad
+        loss = self.transformer(codes, labels)
+        return loss
+
+
+@register_model
+def register_ctc_code_generator(opt_net, opt):
+    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    model = CtcCodeGenerator()
+    inps = torch.randint(0,36, (4, 300))
+    pads = torch.randint(0,100, (4,300))
+    repeats = torch.randint(0,20, (4,300))
+    loss = model(inps, pads, repeats)
+    print(loss.shape)
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index a3c4edd2..be9cabac 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -299,7 +299,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()