Mods to support an autoregressive CTC code generator

James Betker 2022-02-03 19:58:54 -07:00
parent 8132766d38
commit 4249681c4b
3 changed files with 135 additions and 8 deletions

View File

@@ -1,8 +1,8 @@
import hashlib
import os
import os
import random
import sys
from itertools import groupby
import torch
import torch.nn.functional as F
@@ -12,8 +12,6 @@ from tqdm import tqdm
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence, sequence_to_text
from utils.util import opt_get
@@ -53,6 +51,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
        self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
        self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
        self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
        self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
        self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
        self.text_cleaners = hparams.text_cleaners
        self.sample_rate = hparams.sample_rate
@@ -114,6 +113,39 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
            print(f"error parsing random offset: {sys.exc_info()}")
            return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again.
    def get_ctc_metadata(self, codes):
        grouped = groupby(codes.tolist())
        codes, repeats, pads = [], [], [0]
        for val, group in grouped:
            if val == 0:
                pads[-1] = len(list(group))
            else:
                codes.append(val)
                repeats.append(len(list(group)))
                pads.append(0)
        codes = torch.tensor(codes)
        # These clip values are sane maxima; I did not see values above them in the datasets I have access to.
        repeats = torch.clip(torch.tensor(repeats), max=30)
        pads = torch.clip(torch.tensor(pads[:-1]), max=120)
        # Pad or clip the codes to get them to exactly self.max_text_len.
        orig_lens = codes.shape[0]
        if codes.shape[0] < self.max_text_len:
            gap = self.max_text_len - codes.shape[0]
            codes = F.pad(codes, (0, gap))
            repeats = F.pad(repeats, (0, gap))
            pads = F.pad(pads, (0, gap))
        elif codes.shape[0] > self.max_text_len:
            codes = codes[:self.max_text_len]
            repeats = repeats[:self.max_text_len]
            pads = pads[:self.max_text_len]
        return {
            'ctc_raw_codes': codes,
            'ctc_pads': pads,
            'ctc_repeats': repeats,
            'ctc_raw_lengths': orig_lens,
        }
    def __getitem__(self, index):
        self.skipped_items += 1
@@ -130,7 +162,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
            if self.debug_failures:
                print(f"error loading {apt[0]} {sys.exc_info()}")
            return self[(index+1) % len(self)]
        aligned_codes = apt[2]
        raw_codes = apt[2]
        aligned_codes = raw_codes
        actually_skipped_items = self.skipped_items
        self.skipped_items = 0
@@ -166,6 +199,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
        if self.load_conditioning:
            res['conditioning'] = cond
            res['conditioning_contains_self'] = cond_is_self
        if self.produce_ctc_metadata:
            res.update(self.get_ctc_metadata(raw_codes))
        return res
    def __len__(self):
@@ -223,6 +258,7 @@ if __name__ == '__main__':
        'conditioning_length': 44000,
        'use_bpe_tokenizer': False,
        'load_aligned_codes': True,
        'produce_ctc_metadata': True,
    }
    from data import create_dataset, create_dataloader
@@ -236,10 +272,14 @@ if __name__ == '__main__':
    dl = create_dataloader(ds, params, collate_fn=c)
    i = 0
    m = None
    max_pads, max_repeats = 0, 0
    for i, b in tqdm(enumerate(dl)):
        for ib in range(batch_sz):
            max_pads = max(max_pads, b['ctc_pads'].max())
            max_repeats = max(max_repeats, b['ctc_repeats'].max())
            print(f'{i} {ib} {b["real_text"][ib]}')
            save(b, i, ib, 'wav')
        if i > 5:
            break
            #save(b, i, ib, 'wav')
        #if i > 5:
        #    break
    print(max_pads, max_repeats)
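    # Scanning max_pads / max_repeats across the whole dataset is what motivates the
    # clip maxima (120 and 30) hard-coded in get_ctc_metadata above.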

View File

@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from x_transformers import Encoder, XTransformer
from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
from trainer.networks import register_model
from utils.util import opt_get
class CheckpointedXTransformerEncoder(nn.Module):
    """
    Wraps an XTransformer and applies a CheckpointedLayer to each of its encoder and decoder
    attention layers, so activations are recomputed in the backward pass rather than stored.
    """
    def __init__(self, **xtransformer_kwargs):
        super().__init__()
        self.transformer = XTransformer(**xtransformer_kwargs)

        for xform in [self.transformer.encoder, self.transformer.decoder.net]:
            for i in range(len(xform.attn_layers.layers)):
                n, b, r = xform.attn_layers.layers[i]  # x_transformers stores each layer as a (norm, block, residual) triple
                xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])

    def forward(self, *args, **kwargs):
        return self.transformer(*args, **kwargs)
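# CheckpointedLayer is imported from models.gpt_voice.unet_diffusion_tts6 and is not shown in
# this diff. A minimal sketch of the idea (an assumption about its contents, not the committed
# implementation): re-run the wrapped block during backprop instead of storing its activations.
#
#     from functools import partial
#     from torch.utils.checkpoint import checkpoint
#
#     class CheckpointedLayer(nn.Module):
#         def __init__(self, wrap):
#             super().__init__()
#             self.wrap = wrap
#
#         def forward(self, x, *args, **kwargs):
#             # checkpoint() does not forward kwargs, so bind them first.
#             return checkpoint(partial(self.wrap, **kwargs), x, *args)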
class CtcCodeGenerator(nn.Module):
    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30):
        super().__init__()
        self.max_pad = max_pad
        self.max_repeat = max_repeat
        self.transformer = XTransformer(
            dim=model_dim,
            enc_depth=layers,
            dec_depth=layers,
            enc_heads=num_heads,
            dec_heads=num_heads,
            enc_num_tokens=ctc_codes,
            dec_num_tokens=(max_pad+1)*(max_repeat+1),
            enc_max_seq_len=-1,
            dec_max_seq_len=-1,
            enc_ff_dropout=dropout,
            enc_attn_dropout=dropout,
            enc_use_rmsnorm=True,
            enc_ff_glu=True,
            enc_rotary_pos_emb=True,
            dec_ff_dropout=dropout,
            dec_attn_dropout=dropout,
            dec_use_rmsnorm=True,
            dec_ff_glu=True,
            dec_rotary_pos_emb=True)
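        # With the defaults (max_pad=120, max_repeat=30), the decoder vocabulary covers
        # (120 + 1) * (30 + 1) = 3751 tokens: one token per possible (pad, repeat) pair.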
    def forward(self, codes, pads, repeats, unpadded_lengths=None):
        if unpadded_lengths is not None:
            max_len = unpadded_lengths.max()
            codes = codes[:, :max_len]
            pads = pads[:, :max_len]
            repeats = repeats[:, :max_len]
        if pads.max() > self.max_pad:
            print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
            pads = torch.clip(pads, 0, self.max_pad)
        if repeats.max() > self.max_repeat:
            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
            repeats = torch.clip(repeats, 0, self.max_repeat)
        assert codes.max() < 36, codes.max()
        # Pack each (pad, repeat) pair into a single decoder token. The stride must be
        # max_pad + 1, since pads takes max_pad + 1 distinct values; a stride of max_pad
        # would alias e.g. (pad=max_pad, repeat=0) with (pad=0, repeat=1).
        labels = pads + repeats * (self.max_pad + 1)
        loss = self.transformer(codes, labels)
        return loss
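# Illustrative inverse of the label packing in CtcCodeGenerator.forward. This helper is
# hypothetical and not part of the commit, but shows how (pad, repeat) pairs would be
# recovered from sampled decoder tokens.
def decode_pads_repeats(labels, max_pad=120):
    pads = labels % (max_pad + 1)
    repeats = labels // (max_pad + 1)
    return pads, repeats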
@register_model
def register_ctc_code_generator(opt_net, opt):
    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
    model = CtcCodeGenerator()
    inps = torch.randint(0, 36, (4, 300))
    pads = torch.randint(0, 100, (4, 300))
    repeats = torch.randint(0, 20, (4, 300))
    loss = model(inps, pads, repeats)
    print(loss.shape)

View File

@@ -299,7 +299,7 @@ class Trainer:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml')
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()