forked from mrq/DL-Art-School
Mods to support a autoregressive CTC code generator
This commit is contained in:
parent
8132766d38
commit
4249681c4b
|
@ -1,8 +1,8 @@
|
|||
import hashlib
|
||||
import os
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from itertools import groupby
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -12,8 +12,6 @@ from tqdm import tqdm
|
|||
|
||||
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
|
||||
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
|
||||
from models.tacotron2.taco_utils import load_filepaths_and_text
|
||||
from models.tacotron2.text import text_to_sequence, sequence_to_text
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
|
@ -53,6 +51,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
|
||||
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
|
||||
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
|
||||
self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
|
||||
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
|
||||
self.text_cleaners = hparams.text_cleaners
|
||||
self.sample_rate = hparams.sample_rate
|
||||
|
@ -114,6 +113,39 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
print(f"error parsing random offset: {sys.exc_info()}")
|
||||
return self.load_random_line(depth=depth+1) # On failure, just recurse and try again.
|
||||
|
||||
def get_ctc_metadata(self, codes):
|
||||
grouped = groupby(codes.tolist())
|
||||
codes, repeats, pads = [], [], [0]
|
||||
for val, group in grouped:
|
||||
if val == 0:
|
||||
pads[-1] = len(list(group))
|
||||
else:
|
||||
codes.append(val)
|
||||
repeats.append(len(list(group)))
|
||||
pads.append(0)
|
||||
|
||||
codes = torch.tensor(codes)
|
||||
# These clip values are sane maximum values which I did not see in the datasets I have access to.
|
||||
repeats = torch.clip(torch.tensor(repeats), max=30)
|
||||
pads = torch.clip(torch.tensor(pads[:-1]), max=120)
|
||||
|
||||
# Pad or clip the codes to get them to exactly self.max_text_len
|
||||
orig_lens = codes.shape[0]
|
||||
if codes.shape[0] < self.max_text_len:
|
||||
gap = self.max_text_len - codes.shape[0]
|
||||
codes = F.pad(codes, (0, gap))
|
||||
repeats = F.pad(repeats, (0, gap))
|
||||
pads = F.pad(pads, (0, gap))
|
||||
elif codes.shape[0] > self.max_text_len:
|
||||
codes = codes[:self.max_text_len]
|
||||
repeats = codes[:self.max_text_len]
|
||||
pads = pads[:self.max_text_len]
|
||||
return {
|
||||
'ctc_raw_codes': codes,
|
||||
'ctc_pads': pads,
|
||||
'ctc_repeats': repeats,
|
||||
'ctc_raw_lengths': orig_lens,
|
||||
}
|
||||
|
||||
def __getitem__(self, index):
|
||||
self.skipped_items += 1
|
||||
|
@ -130,7 +162,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
if self.debug_failures:
|
||||
print(f"error loading {apt[0]} {sys.exc_info()}")
|
||||
return self[(index+1) % len(self)]
|
||||
aligned_codes = apt[2]
|
||||
raw_codes = apt[2]
|
||||
aligned_codes = raw_codes
|
||||
|
||||
actually_skipped_items = self.skipped_items
|
||||
self.skipped_items = 0
|
||||
|
@ -166,6 +199,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
if self.load_conditioning:
|
||||
res['conditioning'] = cond
|
||||
res['conditioning_contains_self'] = cond_is_self
|
||||
if self.produce_ctc_metadata:
|
||||
res.update(self.get_ctc_metadata(raw_codes))
|
||||
return res
|
||||
|
||||
def __len__(self):
|
||||
|
@ -223,6 +258,7 @@ if __name__ == '__main__':
|
|||
'conditioning_length': 44000,
|
||||
'use_bpe_tokenizer': False,
|
||||
'load_aligned_codes': True,
|
||||
'produce_ctc_metadata': True,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
|
||||
|
@ -236,10 +272,14 @@ if __name__ == '__main__':
|
|||
dl = create_dataloader(ds, params, collate_fn=c)
|
||||
i = 0
|
||||
m = None
|
||||
max_pads, max_repeats = 0, 0
|
||||
for i, b in tqdm(enumerate(dl)):
|
||||
for ib in range(batch_sz):
|
||||
max_pads = max(max_pads, b['ctc_pads'].max())
|
||||
max_repeats = max(max_repeats, b['ctc_repeats'].max())
|
||||
print(f'{i} {ib} {b["real_text"][ib]}')
|
||||
save(b, i, ib, 'wav')
|
||||
if i > 5:
|
||||
break
|
||||
#save(b, i, ib, 'wav')
|
||||
#if i > 5:
|
||||
# break
|
||||
print(max_pads, max_repeats)
|
||||
|
||||
|
|
87
codes/models/gpt_voice/ctc_code_generator.py
Normal file
87
codes/models/gpt_voice/ctc_code_generator.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from x_transformers import Encoder, XTransformer
|
||||
|
||||
from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class CheckpointedXTransformerEncoder(nn.Module):
|
||||
"""
|
||||
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
||||
to channels-last that XTransformer expects.
|
||||
"""
|
||||
def __init__(self, **xtransformer_kwargs):
|
||||
super().__init__()
|
||||
self.transformer = XTransformer(**xtransformer_kwargs)
|
||||
|
||||
for xform in [self.transformer.encoder, self.transformer.decoder.net]:
|
||||
for i in range(len(xform.attn_layers.layers)):
|
||||
n, b, r = xform.attn_layers.layers[i]
|
||||
xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.transformer(*args, **kwargs)
|
||||
|
||||
|
||||
class CtcCodeGenerator(nn.Module):
|
||||
def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30):
|
||||
super().__init__()
|
||||
self.max_pad = max_pad
|
||||
self.max_repeat = max_repeat
|
||||
self.transformer = XTransformer(
|
||||
dim=model_dim,
|
||||
enc_depth=layers,
|
||||
dec_depth=layers,
|
||||
enc_heads=num_heads,
|
||||
dec_heads=num_heads,
|
||||
enc_num_tokens=ctc_codes,
|
||||
dec_num_tokens=(max_pad+1)*(max_repeat+1),
|
||||
enc_max_seq_len=-1,
|
||||
dec_max_seq_len=-1,
|
||||
|
||||
enc_ff_dropout=dropout,
|
||||
enc_attn_dropout=dropout,
|
||||
enc_use_rmsnorm=True,
|
||||
enc_ff_glu=True,
|
||||
enc_rotary_pos_emb=True,
|
||||
dec_ff_dropout=dropout,
|
||||
dec_attn_dropout=dropout,
|
||||
dec_use_rmsnorm=True,
|
||||
dec_ff_glu=True,
|
||||
dec_rotary_pos_emb=True)
|
||||
|
||||
def forward(self, codes, pads, repeats, unpadded_lengths=None):
|
||||
if unpadded_lengths is not None:
|
||||
max_len = unpadded_lengths.max()
|
||||
codes = codes[:, :max_len]
|
||||
pads = pads[:, :max_len]
|
||||
repeats = repeats[:, :max_len]
|
||||
|
||||
if pads.max() > self.max_pad:
|
||||
print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
|
||||
pads = torch.clip(pads, 0, self.max_pad)
|
||||
if repeats.max() > self.max_repeat:
|
||||
print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
|
||||
repeats = torch.clip(repeats, 0, self.max_repeat)
|
||||
assert codes.max() < 36, codes.max()
|
||||
|
||||
labels = pads + repeats * self.max_pad
|
||||
loss = self.transformer(codes, labels)
|
||||
return loss
|
||||
|
||||
|
||||
@register_model
|
||||
def register_ctc_code_generator(opt_net, opt):
|
||||
return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = CtcCodeGenerator()
|
||||
inps = torch.randint(0,36, (4, 300))
|
||||
pads = torch.randint(0,100, (4,300))
|
||||
repeats = torch.randint(0,20, (4,300))
|
||||
loss = model(inps, pads, repeats)
|
||||
print(loss.shape)
|
|
@ -299,7 +299,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user