forked from mrq/DL-Art-School
Mods to support an autoregressive CTC code generator
This commit is contained in:
parent
8132766d38
commit
4249681c4b
|
@ -1,8 +1,8 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
from itertools import groupby
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -12,8 +12,6 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
|
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
|
||||||
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
|
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
|
||||||
from models.tacotron2.taco_utils import load_filepaths_and_text
|
|
||||||
from models.tacotron2.text import text_to_sequence, sequence_to_text
|
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,6 +51,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
|
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
|
||||||
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
|
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
|
||||||
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
|
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
|
||||||
|
self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
|
||||||
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
|
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
|
||||||
self.text_cleaners = hparams.text_cleaners
|
self.text_cleaners = hparams.text_cleaners
|
||||||
self.sample_rate = hparams.sample_rate
|
self.sample_rate = hparams.sample_rate
|
||||||
|
@ -114,6 +113,39 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
print(f"error parsing random offset: {sys.exc_info()}")
|
print(f"error parsing random offset: {sys.exc_info()}")
|
||||||
return self.load_random_line(depth=depth+1) # On failure, just recurse and try again.
|
return self.load_random_line(depth=depth+1) # On failure, just recurse and try again.
|
||||||
|
|
||||||
|
def get_ctc_metadata(self, codes):
    """
    Converts a raw CTC code sequence into a compact (codes, repeats, pads) representation.

    Consecutive duplicates are collapsed with itertools.groupby: each nonzero code is
    recorded once in `codes` together with its run length in `repeats`; a run of 0 (the
    CTC blank) is recorded in `pads` as the blank count preceding the *next* nonzero
    code. Trailing blanks are discarded. All three tensors are then padded or truncated
    to exactly self.max_text_len.

    Args:
        codes: 1-D integer tensor of raw CTC codes; 0 is the blank token.

    Returns:
        dict with:
            'ctc_raw_codes': collapsed codes, shape (self.max_text_len,)
            'ctc_pads': blank count preceding each code, clipped to 120
            'ctc_repeats': run length of each code, clipped to 30
            'ctc_raw_lengths': number of collapsed codes before padding/truncation
    """
    grouped = groupby(codes.tolist())
    codes, repeats, pads = [], [], [0]
    for val, group in grouped:
        if val == 0:
            # Blank run: attribute it as padding before the next code.
            pads[-1] = len(list(group))
        else:
            codes.append(val)
            repeats.append(len(list(group)))
            pads.append(0)

    codes = torch.tensor(codes)
    # These clip values are sane maximum values which I did not see in the datasets I have access to.
    repeats = torch.clip(torch.tensor(repeats), max=30)
    pads = torch.clip(torch.tensor(pads[:-1]), max=120)

    # Pad or clip the codes to get them to exactly self.max_text_len
    orig_lens = codes.shape[0]
    if codes.shape[0] < self.max_text_len:
        gap = self.max_text_len - codes.shape[0]
        codes = F.pad(codes, (0, gap))
        repeats = F.pad(repeats, (0, gap))
        pads = F.pad(pads, (0, gap))
    elif codes.shape[0] > self.max_text_len:
        codes = codes[:self.max_text_len]
        # BUGFIX: was `repeats = codes[:self.max_text_len]`, which overwrote the repeat
        # counts with the already-truncated codes themselves.
        repeats = repeats[:self.max_text_len]
        pads = pads[:self.max_text_len]
    return {
        'ctc_raw_codes': codes,
        'ctc_pads': pads,
        'ctc_repeats': repeats,
        'ctc_raw_lengths': orig_lens,
    }
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
self.skipped_items += 1
|
self.skipped_items += 1
|
||||||
|
@ -130,7 +162,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
if self.debug_failures:
|
if self.debug_failures:
|
||||||
print(f"error loading {apt[0]} {sys.exc_info()}")
|
print(f"error loading {apt[0]} {sys.exc_info()}")
|
||||||
return self[(index+1) % len(self)]
|
return self[(index+1) % len(self)]
|
||||||
aligned_codes = apt[2]
|
raw_codes = apt[2]
|
||||||
|
aligned_codes = raw_codes
|
||||||
|
|
||||||
actually_skipped_items = self.skipped_items
|
actually_skipped_items = self.skipped_items
|
||||||
self.skipped_items = 0
|
self.skipped_items = 0
|
||||||
|
@ -166,6 +199,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
if self.load_conditioning:
|
if self.load_conditioning:
|
||||||
res['conditioning'] = cond
|
res['conditioning'] = cond
|
||||||
res['conditioning_contains_self'] = cond_is_self
|
res['conditioning_contains_self'] = cond_is_self
|
||||||
|
if self.produce_ctc_metadata:
|
||||||
|
res.update(self.get_ctc_metadata(raw_codes))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
@ -223,6 +258,7 @@ if __name__ == '__main__':
|
||||||
'conditioning_length': 44000,
|
'conditioning_length': 44000,
|
||||||
'use_bpe_tokenizer': False,
|
'use_bpe_tokenizer': False,
|
||||||
'load_aligned_codes': True,
|
'load_aligned_codes': True,
|
||||||
|
'produce_ctc_metadata': True,
|
||||||
}
|
}
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
|
|
||||||
|
@ -236,10 +272,14 @@ if __name__ == '__main__':
|
||||||
dl = create_dataloader(ds, params, collate_fn=c)
|
dl = create_dataloader(ds, params, collate_fn=c)
|
||||||
i = 0
|
i = 0
|
||||||
m = None
|
m = None
|
||||||
|
max_pads, max_repeats = 0, 0
|
||||||
for i, b in tqdm(enumerate(dl)):
|
for i, b in tqdm(enumerate(dl)):
|
||||||
for ib in range(batch_sz):
|
for ib in range(batch_sz):
|
||||||
|
max_pads = max(max_pads, b['ctc_pads'].max())
|
||||||
|
max_repeats = max(max_repeats, b['ctc_repeats'].max())
|
||||||
print(f'{i} {ib} {b["real_text"][ib]}')
|
print(f'{i} {ib} {b["real_text"][ib]}')
|
||||||
save(b, i, ib, 'wav')
|
#save(b, i, ib, 'wav')
|
||||||
if i > 5:
|
#if i > 5:
|
||||||
break
|
# break
|
||||||
|
print(max_pads, max_repeats)
|
||||||
|
|
||||||
|
|
87
codes/models/gpt_voice/ctc_code_generator.py
Normal file
87
codes/models/gpt_voice/ctc_code_generator.py
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from x_transformers import Encoder, XTransformer
|
||||||
|
|
||||||
|
from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
|
||||||
|
from trainer.networks import register_model
|
||||||
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointedXTransformerEncoder(nn.Module):
    """
    Thin wrapper around an XTransformer that rewraps every attention block (in both the
    encoder and the decoder) with CheckpointedLayer so gradients are recomputed instead
    of stored, trading compute for activation memory.
    """
    def __init__(self, **xtransformer_kwargs):
        super().__init__()
        self.transformer = XTransformer(**xtransformer_kwargs)

        # Each entry of attn_layers.layers is a (norm, block, residual) triple;
        # swap the block for a checkpointed version in-place.
        for sub in (self.transformer.encoder, self.transformer.decoder.net):
            attn_layers = sub.attn_layers.layers
            for idx in range(len(attn_layers)):
                norm, block, residual = attn_layers[idx]
                attn_layers[idx] = nn.ModuleList([norm, CheckpointedLayer(block), residual])

    def forward(self, *args, **kwargs):
        # Pure pass-through; checkpointing was installed at construction time.
        return self.transformer(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class CtcCodeGenerator(nn.Module):
    """
    Autoregressive encoder/decoder transformer that, given a sequence of CTC codes,
    predicts for each code the number of blank (pad) frames preceding it and the number
    of times it repeats.

    Pads and repeats are jointly encoded as a single decoder token:
        label = pad + repeat * (max_pad + 1)
    which bijectively maps the (pad, repeat) pair onto the decoder vocabulary of size
    (max_pad+1)*(max_repeat+1).
    """
    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30):
        super().__init__()
        self.ctc_codes = ctc_codes
        self.max_pad = max_pad
        self.max_repeat = max_repeat
        self.transformer = XTransformer(
            dim=model_dim,
            enc_depth=layers,
            dec_depth=layers,
            enc_heads=num_heads,
            dec_heads=num_heads,
            enc_num_tokens=ctc_codes,
            # One decoder token per (pad, repeat) pair.
            dec_num_tokens=(max_pad+1)*(max_repeat+1),
            # NOTE(review): -1 presumably disables absolute positional embeddings in favor
            # of the rotary embeddings enabled below — confirm against x_transformers docs.
            enc_max_seq_len=-1,
            dec_max_seq_len=-1,

            enc_ff_dropout=dropout,
            enc_attn_dropout=dropout,
            enc_use_rmsnorm=True,
            enc_ff_glu=True,
            enc_rotary_pos_emb=True,
            dec_ff_dropout=dropout,
            dec_attn_dropout=dropout,
            dec_use_rmsnorm=True,
            dec_ff_glu=True,
            dec_rotary_pos_emb=True)

    def forward(self, codes, pads, repeats, unpadded_lengths=None):
        """
        Computes the transformer training loss for a batch of CTC metadata.

        Args:
            codes: (batch, seq) int tensor of collapsed CTC codes, values in [0, ctc_codes).
            pads: (batch, seq) int tensor of blank counts preceding each code.
            repeats: (batch, seq) int tensor of run lengths for each code.
            unpadded_lengths: optional (batch,) tensor of true sequence lengths; when
                given, the batch is trimmed to the longest real sequence to save compute.

        Returns:
            The loss produced by the underlying XTransformer.
        """
        if unpadded_lengths is not None:
            max_len = unpadded_lengths.max()
            codes = codes[:, :max_len]
            pads = pads[:, :max_len]
            repeats = repeats[:, :max_len]

        # Out-of-range pads/repeats are clipped (loudly) rather than crashing training.
        if pads.max() > self.max_pad:
            print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
            pads = torch.clip(pads, 0, self.max_pad)
        if repeats.max() > self.max_repeat:
            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
            repeats = torch.clip(repeats, 0, self.max_repeat)
        assert codes.max() < self.ctc_codes, codes.max()

        # BUGFIX: was `pads + repeats * self.max_pad`, which is not a bijection —
        # e.g. (pad=max_pad, repeat=0) and (pad=0, repeat=1) collided on the same label.
        # The stride over the pad dimension must be max_pad+1 (the number of pad values).
        labels = pads + repeats * (self.max_pad + 1)
        loss = self.transformer(codes, labels)
        return loss
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def register_ctc_code_generator(opt_net, opt):
    """Trainer registry hook: constructs a CtcCodeGenerator from the network options."""
    kwargs = opt_get(opt_net, ['kwargs'], {})
    return CtcCodeGenerator(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Smoke test: random inputs within the default code/pad/repeat ranges.
    generator = CtcCodeGenerator()
    code_batch = torch.randint(0, 36, (4, 300))
    pad_batch = torch.randint(0, 100, (4, 300))
    repeat_batch = torch.randint(0, 20, (4, 300))
    loss = generator(code_batch, pad_batch, repeat_batch)
    print(loss.shape)
|
|
@ -299,7 +299,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user