diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index 638ec81a..a3a09059 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -1,8 +1,8 @@
 import hashlib
 import os
-import os
 import random
 import sys
+from itertools import groupby
 
 import torch
 import torch.nn.functional as F
@@ -12,8 +12,6 @@ from tqdm import tqdm
 
 from data.audio.paired_voice_audio_dataset import CharacterTokenizer
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
-from models.tacotron2.taco_utils import load_filepaths_and_text
-from models.tacotron2.text import text_to_sequence, sequence_to_text
 from utils.util import opt_get
 
 
@@ -53,6 +51,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
         self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
         self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
+        self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
         self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
         self.text_cleaners = hparams.text_cleaners
         self.sample_rate = hparams.sample_rate
@@ -114,6 +113,39 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             print(f"error parsing random offset: {sys.exc_info()}")
             return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again.
 
+    def get_ctc_metadata(self, codes):
+        grouped = groupby(codes.tolist())
+        codes, repeats, pads = [], [], [0]
+        for val, group in grouped:
+            if val == 0:
+                pads[-1] = len(list(group))
+            else:
+                codes.append(val)
+                repeats.append(len(list(group)))
+                pads.append(0)
+
+        codes = torch.tensor(codes)
+        # These clip values are sane maximums; nothing above them was observed in the datasets I have access to.
+        repeats = torch.clip(torch.tensor(repeats), max=30)
+        pads = torch.clip(torch.tensor(pads[:-1]), max=120)
+
+        # Pad or clip the codes to get them to exactly self.max_text_len.
+        orig_lens = codes.shape[0]
+        if codes.shape[0] < self.max_text_len:
+            gap = self.max_text_len - codes.shape[0]
+            codes = F.pad(codes, (0, gap))
+            repeats = F.pad(repeats, (0, gap))
+            pads = F.pad(pads, (0, gap))
+        elif codes.shape[0] > self.max_text_len:
+            codes = codes[:self.max_text_len]
+            repeats = repeats[:self.max_text_len]
+            pads = pads[:self.max_text_len]
+        return {
+            'ctc_raw_codes': codes,
+            'ctc_pads': pads,
+            'ctc_repeats': repeats,
+            'ctc_raw_lengths': orig_lens,
+        }
 
     def __getitem__(self, index):
         self.skipped_items += 1
@@ -130,7 +162,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             if self.debug_failures:
                 print(f"error loading {apt[0]} {sys.exc_info()}")
             return self[(index+1) % len(self)]
-        aligned_codes = apt[2]
+        raw_codes = apt[2]
+        aligned_codes = raw_codes
 
         actually_skipped_items = self.skipped_items
         self.skipped_items = 0
@@ -166,6 +199,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         if self.load_conditioning:
             res['conditioning'] = cond
             res['conditioning_contains_self'] = cond_is_self
+        if self.produce_ctc_metadata:
+            res.update(self.get_ctc_metadata(raw_codes))
         return res
 
     def __len__(self):
@@ -223,6 +258,7 @@ if __name__ == '__main__':
         'conditioning_length': 44000,
         'use_bpe_tokenizer': False,
         'load_aligned_codes': True,
+        'produce_ctc_metadata': True,
     }
     from data import create_dataset, create_dataloader
 
@@ -236,10 +272,14 @@ if __name__ == '__main__':
     dl = create_dataloader(ds, params, collate_fn=c)
     i = 0
     m = None
+    max_pads, max_repeats = 0, 0
     for i, b in tqdm(enumerate(dl)):
         for ib in range(batch_sz):
+            max_pads = max(max_pads, b['ctc_pads'].max())
+            max_repeats = max(max_repeats, b['ctc_repeats'].max())
             print(f'{i} {ib} {b["real_text"][ib]}')
-            save(b, i, ib, 'wav')
-        if i > 5:
-            break
+            #save(b, i, ib, 'wav')
+        #if i > 5:
+        #    break
+    print(max_pads, max_repeats)
 
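For reference, here is a standalone sketch (illustrative toy input, not part of the patch) of the decomposition `get_ctc_metadata` performs above: a CTC-aligned code stream is split into its non-blank codes, the repeat count of each code, and the length of the blank (zero) run preceding each code.

```python
from itertools import groupby

import torch

# Hypothetical CTC alignment; 0 is the blank token.
aligned = torch.tensor([0, 0, 5, 5, 5, 0, 7, 9, 9, 0, 0, 0])

codes, repeats, pads = [], [], [0]
for val, group in groupby(aligned.tolist()):
    if val == 0:
        pads[-1] = len(list(group))   # blanks pad the upcoming code slot
    else:
        codes.append(val)
        repeats.append(len(list(group)))
        pads.append(0)

print(codes)      # [5, 7, 9]
print(repeats)    # [3, 1, 2]
print(pads[:-1])  # [2, 1, 0] (the trailing blank run is dropped)
```
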
diff --git a/codes/models/gpt_voice/ctc_code_generator.py b/codes/models/gpt_voice/ctc_code_generator.py
new file mode 100644
index 00000000..5dfce669
--- /dev/null
+++ b/codes/models/gpt_voice/ctc_code_generator.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from x_transformers import Encoder, XTransformer
+
+from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class CheckpointedXTransformerEncoder(nn.Module):
+    """
+    Wraps an XTransformer and applies CheckpointedLayer to each block in its encoder and decoder
+    stacks, so that activations are recomputed during the backward pass instead of being stored.
+    """
+    def __init__(self, **xtransformer_kwargs):
+        super().__init__()
+        self.transformer = XTransformer(**xtransformer_kwargs)
+
+        for xform in [self.transformer.encoder, self.transformer.decoder.net]:
+            for i in range(len(xform.attn_layers.layers)):
+                n, b, r = xform.attn_layers.layers[i]
+                xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
+
+    def forward(self, *args, **kwargs):
+        return self.transformer(*args, **kwargs)
+
+
+class CtcCodeGenerator(nn.Module):
+    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30):
+        super().__init__()
+        self.ctc_codes = ctc_codes
+        self.max_pad = max_pad
+        self.max_repeat = max_repeat
+        self.transformer = XTransformer(
+            dim=model_dim,
+            enc_depth=layers,
+            dec_depth=layers,
+            enc_heads=num_heads,
+            dec_heads=num_heads,
+            enc_num_tokens=ctc_codes,
+            dec_num_tokens=(max_pad+1)*(max_repeat+1),
+            enc_max_seq_len=-1,
+            dec_max_seq_len=-1,
+
+            enc_ff_dropout=dropout,
+            enc_attn_dropout=dropout,
+            enc_use_rmsnorm=True,
+            enc_ff_glu=True,
+            enc_rotary_pos_emb=True,
+            dec_ff_dropout=dropout,
+            dec_attn_dropout=dropout,
+            dec_use_rmsnorm=True,
+            dec_ff_glu=True,
+            dec_rotary_pos_emb=True)
+
+    def forward(self, codes, pads, repeats, unpadded_lengths=None):
+        if unpadded_lengths is not None:
+            max_len = unpadded_lengths.max()
+            codes = codes[:, :max_len]
+            pads = pads[:, :max_len]
+            repeats = repeats[:, :max_len]
+
+        if pads.max() > self.max_pad:
+            print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
+            pads = torch.clip(pads, 0, self.max_pad)
+        if repeats.max() > self.max_repeat:
+            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
+            repeats = torch.clip(repeats, 0, self.max_repeat)
+        assert codes.max() < self.ctc_codes, codes.max()
+
+        # Stride by (max_pad + 1) so every (pad, repeat) pair maps to a unique decoder token.
+        labels = pads + repeats * (self.max_pad + 1)
+        loss = self.transformer(codes, labels)
+        return loss
+
+
+@register_model
+def register_ctc_code_generator(opt_net, opt):
+    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    model = CtcCodeGenerator()
+    inps = torch.randint(0,36, (4, 300))
+    pads = torch.randint(0,100, (4,300))
+    repeats = torch.randint(0,20, (4,300))
+    loss = model(inps, pads, repeats)
+    print(loss.shape)
\ No newline at end of file
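To make the label packing in `CtcCodeGenerator.forward` concrete, here is a small illustrative check (not part of the patch) of why the stride must be `max_pad + 1` for the packing to fit the `dec_num_tokens=(max_pad+1)*(max_repeat+1)` vocabulary without collisions:

```python
# The decoder vocabulary jointly encodes (pad, repeat) pairs. Striding by
# (max_pad + 1) makes the packing bijective over (max_pad+1)*(max_repeat+1)
# tokens; a stride of max_pad would collide, e.g. (pad=120, repeat=3) and
# (pad=0, repeat=4) would both map to label 480.
max_pad, max_repeat = 120, 30

def encode(pad, repeat):
    return pad + repeat * (max_pad + 1)

def decode(label):
    return label % (max_pad + 1), label // (max_pad + 1)

assert decode(encode(120, 3)) == (120, 3)
assert encode(120, 3) != encode(0, 4)        # distinct under the (max_pad + 1) stride
assert 120 + 3 * max_pad == 0 + 4 * max_pad  # ...but identical under a max_pad stride
print(encode(max_pad, max_repeat) + 1)       # 3751 == (max_pad+1)*(max_repeat+1) tokens needed
```
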
+ """ + def __init__(self, **xtransformer_kwargs): + super().__init__() + self.transformer = XTransformer(**xtransformer_kwargs) + + for xform in [self.transformer.encoder, self.transformer.decoder.net]: + for i in range(len(xform.attn_layers.layers)): + n, b, r = xform.attn_layers.layers[i] + xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) + + def forward(self, *args, **kwargs): + return self.transformer(*args, **kwargs) + + +class CtcCodeGenerator(nn.Module): + def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30): + super().__init__() + self.max_pad = max_pad + self.max_repeat = max_repeat + self.transformer = XTransformer( + dim=model_dim, + enc_depth=layers, + dec_depth=layers, + enc_heads=num_heads, + dec_heads=num_heads, + enc_num_tokens=ctc_codes, + dec_num_tokens=(max_pad+1)*(max_repeat+1), + enc_max_seq_len=-1, + dec_max_seq_len=-1, + + enc_ff_dropout=dropout, + enc_attn_dropout=dropout, + enc_use_rmsnorm=True, + enc_ff_glu=True, + enc_rotary_pos_emb=True, + dec_ff_dropout=dropout, + dec_attn_dropout=dropout, + dec_use_rmsnorm=True, + dec_ff_glu=True, + dec_rotary_pos_emb=True) + + def forward(self, codes, pads, repeats, unpadded_lengths=None): + if unpadded_lengths is not None: + max_len = unpadded_lengths.max() + codes = codes[:, :max_len] + pads = pads[:, :max_len] + repeats = repeats[:, :max_len] + + if pads.max() > self.max_pad: + print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}") + pads = torch.clip(pads, 0, self.max_pad) + if repeats.max() > self.max_repeat: + print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}") + repeats = torch.clip(repeats, 0, self.max_repeat) + assert codes.max() < 36, codes.max() + + labels = pads + repeats * self.max_pad + loss = self.transformer(codes, labels) + return loss + + +@register_model +def register_ctc_code_generator(opt_net, opt): + return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {})) + + +if __name__ == '__main__': + model = CtcCodeGenerator() + inps = torch.randint(0,36, (4, 300)) + pads = torch.randint(0,100, (4,300)) + repeats = torch.randint(0,20, (4,300)) + loss = model(inps, pads, repeats) + print(loss.shape) \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index a3c4edd2..be9cabac 100644 --- a/codes/train.py +++ b/codes/train.py @@ -299,7 +299,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()