GrandConjoinedDataset

This commit is contained in:
James Betker 2021-12-23 14:32:33 -07:00
parent b9de8a8eda
commit e55d949855
8 changed files with 371 additions and 11 deletions

View File

@ -86,6 +86,8 @@ def create_dataset(dataset_opt, return_collate=False):
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset as D
elif mode == 'unsupervised_audio_with_noise':
from data.audio.audio_with_noise_dataset import AudioWithNoiseDataset as D
elif mode == 'grand_conjoined_voice':
from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)
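# A minimal sketch (not part of this commit) of how the new mode is selected from an options dict;
# the full set of options GrandConjoinedDataset expects is shown in the __main__ block of the new
# dataset file later in this commit.
# ds = create_dataset({'mode': 'grand_conjoined_voice', ...})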

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,157 @@
import os
import random
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from munch import munchify
from tqdm import tqdm
from transformers import GPT2TokenizerFast
from data.audio.unsupervised_audio_dataset import load_audio, UnsupervisedAudioDataset
from data.text.hf_datasets_wrapper import HfDataset
from data.util import find_files_of_type, is_audio_file
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence
from utils.util import opt_get
def build_paired_voice_dataset(args):
from data.audio.paired_voice_audio_dataset import TextWavLoader as D
from models.tacotron2.hparams import create_hparams
default_params = create_hparams()
default_params.update(args)
dataset_opt = munchify(default_params)
return D(dataset_opt)
def clamp(x, minimum, maximum):
return max(minimum, min(x, maximum))
class GrandConjoinedDataset(torch.utils.data.Dataset):
"""
A combined text & speech dataset that joins three separate datasets into a single batch:
1. Unpaired text
2. Unpaired speech
3. Paired speech & text
Supports situations where the underlying data sources for these three elements are differently sized, e.g. you can
have a massive text corpus of 1B elements, a smaller unpaired speech corpus, and a small paired speech<->text corpus.
Performs tokenization at this level, ignoring any tokenization performed by upstream datasets.
"""
def __init__(self, opt):
paired_dataset_args = opt['paired_dataset_args']
unsupervised_audio_args = opt['unsupervised_audio_args']
text_corpus_args = opt['text_corpus_args']
sample_rate = 22050
self.max_paired_audio_length = opt['max_paired_audio_length']
self.max_paired_text_length = opt['max_paired_text_length']
self.max_solo_audio_length = opt['max_solo_audio_length']
self.max_solo_text_length = opt['max_solo_text_length']
self.sample_rate = sample_rate
# Set some sane arguments for all three datasets.
paired_dataset_args['needs_collate'] = False
paired_dataset_args['load_conditioning'] = False
paired_dataset_args['sample_rate'] = sample_rate
paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
paired_dataset_args['max_text_length'] = self.max_paired_text_length
unsupervised_audio_args['sampling_rate'] = sample_rate
unsupervised_audio_args['do_augmentation'] = False
unsupervised_audio_args['resample_clip'] = False
unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
self.speech_and_text = build_paired_voice_dataset(paired_dataset_args)
self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
self.text = HfDataset(**text_corpus_args)
def fetch_text_at(self, i):
try:
txt = self.text[i % len(self.text)]['text']
tok = self.speech_and_text.get_text(txt)
padding_required = self.max_solo_text_length - tok.shape[0]
if padding_required < 0:
# Just truncate, since there is no conditioning required.
tok = tok[:self.max_solo_text_length]
elif padding_required > 0:
tok = F.pad(tok, (0, padding_required))
return txt, tok
except:
# This is fully expected: there are a lot of text strings we intentionally do not
# handle (e.g. ones with emojis, or other languages). Just return another one.
return self.fetch_text_at((i+1) % len(self.text))
def __getitem__(self, i):
snt = self.speech_and_text[i % len(self.speech_and_text)]
sp = self.speech[i % len(self.speech)]
txt, txt_tok = self.fetch_text_at(i % len(self.text))
return {
'paired_audio': snt['wav'],
'paired_audio_lengths': snt['wav_lengths'],
'paired_text': snt['real_text'],
'paired_text_tokens': snt['padded_text'],
'paired_file': snt['filenames'],
'speech_audio': sp['clip'],
'speech_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
'speech_file': sp['path'],
'text_text': txt,
'text_tokens': txt_tok,
}
def __len__(self):
return max(len(self.speech), len(self.speech_and_text), len(self.text))
if __name__ == '__main__':
batch_sz = 8
params = {
'mode': 'grand_conjoined_voice',
'phase': 'train',
'n_workers': 0,
'batch_size': batch_sz,
'max_paired_audio_length': 255995,
'max_paired_text_length': 80,
'max_solo_text_length': 330,
'max_solo_audio_length': 300000,
'paired_dataset_args': {
'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
'fetcher_mode': ['libritts'],
},
'unsupervised_audio_args': {
'path': ['Z:\\bigasr_dataset\\librispeech\\test_clean'],
'cache_path': 'test_cache_delete_me.pth',
},
'text_corpus_args': {
'corpi': [['bookcorpus', '']],
'cache_path': 'Z:\\huggingface_datasets\\cache',
},
}
from data import create_dataset, create_dataloader
ds = create_dataset(params)
dl = create_dataloader(ds, params)
def save(b, i, ib, key):
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
def decode(b, ib, key):
return ds.speech_and_text.tokenizer.decode(b[key][ib].cpu().numpy())
i = 0
m = None
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
#save(b, i, ib, 'paired_audio')
print(f'Paired text: {b["paired_text"][ib]}')
print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
#save(b, i, ib, 'speech_audio')
print(f'Text: {b["text_text"][ib]}')
print(f'Text decoded: {decode(b, ib, "text_tokens")}')

View File

@ -85,7 +85,7 @@ class TextWavLoader(torch.utils.data.Dataset):
self.needs_collate = opt_get(hparams, ['needs_collate'], True)
if not self.needs_collate:
assert self.max_wav_len is not None and self.max_text_len is not None
self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/custom_lowercase_gptvoice_tokenizer.json'))
self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/custom_lowercase_gptvoice_tokenizer_r2.json'))
def get_wav_text_pair(self, audiopath_and_text):
# separate filename and text

View File

@ -143,6 +143,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
output = {
'clip': clips[0],
'clip_lengths': audio_norm.shape[-1],
'path': filename,
}
if self.should_resample_clip:

View File

@ -4,6 +4,7 @@ import datasets
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import ByteLevel
from tokenizers.trainers import BpeTrainer
from data.audio.paired_voice_audio_dataset import load_mozilla_cv, load_voxpopuli, load_tsv
@ -33,9 +34,8 @@ def train():
with open('all_texts.txt', 'r', encoding='utf-8') as at:
ttsd = at.readlines()
bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train']
wkd = datasets.load_dataset('wikipedia', '20200501.en', cache_dir='Z:\\huggingface_datasets\\cache')['train']
allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—ʼ]+$')
def preprocess_word(word, report=False):
word = word.strip().lower()
if not bool(allowed_characters_re.match(word)):
@ -53,14 +53,13 @@ def train():
for i in range(0, len(bcd), batch_size):
yield [preprocess_word(t) for t in bcd[i:i+batch_size]['text']]
print("Processing wikipedia.")
for i in range(0, len(wkd), batch_size):
yield [preprocess_word(t) for t in wkd[i:i+batch_size]['text']]
trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]'], vocab_size=9999)
trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]'], vocab_size=9999, continuing_subword_prefix='$$$')
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(ttsd)+len(bcd)+len(wkd))
tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(ttsd)+len(bcd))
print(tokenizer.decode(tokenizer.encode("i was traveling throughhadslfghds the woods in 1235375t137{{}}").ids))
tokenizer.save('gpt_tts_tokenizer.json')
@ -74,4 +73,4 @@ if __name__ == '__main__':
('Y:\\clips\\books2-transcribed.tsv', 'tsv'),
('Y:\\clips\\podcasts-0-transcribed.tsv', 'tsv')], 'all_texts.txt')
'''
train()
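# A hedged usage sketch (not part of this commit): loading the trained tokenizer back with the
# `tokenizers` library and round-tripping a string; the file name matches the save() call above.
from tokenizers import Tokenizer
tok = Tokenizer.from_file('gpt_tts_tokenizer.json')
enc = tok.encode('i was traveling through the woods')
print(enc.ids)  # BPE token ids
print(tok.decode(enc.ids))  # decoded form; subword prefixes may appear, as in the test print above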

View File

@ -0,0 +1,36 @@
from torch.utils.data import Dataset
import datasets
class HfDataset(Dataset):
"""
Wraps one or more HuggingFace datasets, concatenating them and optionally re-mapping their keys.
"""
def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'):
self.hfd = []
for corpus in corpi:
dataset_name, config = corpus
if config == '':
config = None
self.hfd.append(datasets.load_dataset(dataset_name, config, cache_dir=cache_path)[dataset_spec_key])
self.key_maps = key_maps
def __getitem__(self, item):
for dataset in self.hfd:
if item < len(dataset):
val = dataset[item]
if self.key_maps is None:
return val
else:
return {k: val[v] for k, v in self.key_maps.items()}
else:
item -= len(dataset)
raise IndexError()
def __len__(self):
return sum([len(h) for h in self.hfd])
if __name__ == '__main__':
d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']], dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache')
print(d[5])
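# A hedged sketch (not part of this commit) of key_maps re-mapping: each output key is mapped to a
# column of the underlying HuggingFace dataset, so the example below yields items shaped like
# {'txt': ...} even though bookcorpus stores its strings under the 'text' column.
d2 = HfDataset([['bookcorpus', '']], key_maps={'txt': 'text'}, cache_path='Z:\\huggingface_datasets\\cache')
print(d2[5])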

View File

@ -0,0 +1,165 @@
import random
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from models.arch_util import AttentionBlock
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
class ConditioningEncoder(nn.Module):
def __init__(self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
def forward(self, x):
h = self.init(x)
h = self.attn(h)
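# Use the first position of the attended sequence as the conditioning embedding.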
return h[:, :, 0]
class GptTtsHf(nn.Module):
NUMBER_TEXT_TOKENS = 10000 # The number of tokens produced by our bespoke BPE tokenizer.
START_TEXT_TOKEN = 9999
STOP_TEXT_TOKEN = 0
NUMBER_MEL_CODES = 8194
START_MEL_TOKEN = 8192
STOP_MEL_TOKEN = 8193
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=80, max_mel_tokens=250, max_conditioning_inputs=3,
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60):
super().__init__()
self.max_mel_tokens = max_mel_tokens
self.max_symbols_per_phrase = max_symbols_per_phrase
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
self.gpt = GPT2Model(self.gpt_config)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES)
self.max_conditioning_length = max_conditioning_length
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
return inp, tar
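# A worked illustration (not from the commit): with start_token=9999, stop_token=0 and input
# [5, 7, 2], inp becomes [9999, 5, 7, 2] and tar becomes [5, 7, 2, 0], so predicting tar[t] from
# inp[:t+1] is next-token prediction with explicit start/stop markers.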
def get_logits(self, text_inputs, cond_input, mel_inputs, get_attns=False):
text_emb = self.text_embedding(text_inputs)
cond = self.conditioning_encoder(cond_input).unsqueeze(1)
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
emb = torch.cat([text_emb, cond, mel_emb], dim=1)
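# The sequence fed to GPT-2 is [text tokens | 1 conditioning embedding | mel tokens]; the text and
# mel logits are sliced back out of the hidden states below.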
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state
text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1)
mel_logits = self.final_norm(enc[:, -mel_emb.shape[1]:])
mel_logits = self.mel_head(mel_logits)
mel_logits = mel_logits.permute(0,2,1)
return text_logits, mel_logits
def forward(self, text_inputs, cond_input, mel_targets, wav_lengths, return_attentions=False):
"""
Forward pass
text_inputs: long tensor, (b,t)
cond_input: MEL float tensor, (b,80,s)
mel_targets: long tensor, (b,m)
wav_lengths: long tensor, (b,)
"""
# Overwrite padding areas within the MEL targets (currently coded with the MEL code for <zero>) with the STOP token.
mel_lengths = wav_lengths // self.mel_length_compression
for b in range(len(mel_lengths)):
if mel_lengths[b] < mel_targets.shape[-1]:
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
# Randomly permute the conditioning spectrogram, to destroy any structure present.
cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
if cond_input.shape[-1] > self.max_conditioning_length:
cond_input = cond_input[:,:,:self.max_conditioning_length]
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
text_logits, mel_logits = self.get_logits(text_inputs, cond_input, mel_inputs, get_attns=return_attentions)
if return_attentions:
return mel_logits
loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference(self, text_inputs, cond_input, **hf_generate_kwargs):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
text_emb = self.text_embedding(text_inputs)
# Randomly permute the conditioning spectrogram, to destroy any structure present.
cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
if cond_input.shape[-1] > self.max_conditioning_length:
cond_input = cond_input[:,:,:self.max_conditioning_length]
cond = self.conditioning_encoder(cond_input).unsqueeze(1)
emb = torch.cat([text_emb, cond], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.START_MEL_TOKEN
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs)
return gen[:, fake_inputs.shape[1]:]
@register_model
def register_gpt_tts_hf(opt_net, opt):
return GptTtsHf(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
gpt = GptTtsHf(model_dim=1024, heads=16)
l = gpt(torch.randint(high=len(symbols), size=(2,200)),
torch.arange(0, 80, 1, dtype=torch.float).view(1,80,1).repeat(2,1,800),
torch.randint(high=8192, size=(2,250)),
torch.tensor([150*256,195*256]))
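# A hedged inference sketch (not part of this commit): hf_generate_kwargs are forwarded to
# GPT2InferenceModel.generate, so standard HuggingFace sampling arguments apply. Shapes below
# mirror the forward() test above and are illustrative only.
# codes = gpt.inference(torch.randint(high=9999, size=(1, 40)),
#                       torch.arange(0, 80, 1, dtype=torch.float).view(1, 80, 1).repeat(1, 1, 800),
#                       do_sample=True, top_p=0.9)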