diff --git a/codes/data/audio/paired_voice_audio_dataset.py b/codes/data/audio/paired_voice_audio_dataset.py
index 54cda975..8be74604 100644
--- a/codes/data/audio/paired_voice_audio_dataset.py
+++ b/codes/data/audio/paired_voice_audio_dataset.py
@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F
 import torch.utils.data
 import torchaudio
+from tokenizers import Tokenizer
 from tqdm import tqdm
 from transformers import GPT2TokenizerFast
 
@@ -84,7 +85,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.needs_collate = opt_get(hparams, ['needs_collate'], True)
         if not self.needs_collate:
             assert self.max_wav_len is not None and self.max_text_len is not None
-        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
+        self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/gpt_tts_tokenizer.json'))
 
     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
@@ -94,7 +95,11 @@ class TextWavLoader(torch.utils.data.Dataset):
         return (text_seq, wav, text, audiopath_and_text[0])
 
     def get_text(self, text):
-        return torch.IntTensor(self.tokenizer(text)['input_ids'])
+        tokens = self.tokenizer.encode(text).ids
+        tokens = torch.IntTensor(tokens)
+        assert not torch.any(tokens == 0)
+        assert not torch.any(tokens == 9999)
+        return tokens
 
     def load_conditioning_candidates(self, path):
         candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
diff --git a/codes/data/audio/voice_tokenizer_builder.py b/codes/data/audio/voice_tokenizer_builder.py
new file mode 100644
index 00000000..df8aedb7
--- /dev/null
+++ b/codes/data/audio/voice_tokenizer_builder.py
@@ -0,0 +1,47 @@
+from tokenizers import Tokenizer
+from tokenizers.models import BPE
+from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.trainers import BpeTrainer
+
+from data.audio.paired_voice_audio_dataset import load_mozilla_cv, load_voxpopuli, load_tsv
+from models.tacotron2.taco_utils import load_filepaths_and_text
+
+
+def build_text_file_from_priors(priors, output):
+    with open(output, 'w', encoding='utf-8') as out:
+        for p, fm in priors:
+            if fm == 'lj' or fm == 'libritts':
+                fetcher_fn = load_filepaths_and_text
+            elif fm == 'tsv':
+                fetcher_fn = load_tsv
+            elif fm == 'mozilla_cv':
+                fetcher_fn = load_mozilla_cv
+            elif fm == 'voxpopuli':
+                fetcher_fn = load_voxpopuli
+            else:
+                raise NotImplementedError()
+            apt = fetcher_fn(p)
+            for path, text in apt:
+                out.write(text + "\n")
+            out.flush()
+
+
+def train():
+    trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]'], vocab_size=9999)
+    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+    tokenizer.pre_tokenizer = Whitespace()
+    tokenizer.train(['all_texts.txt'], trainer)
+    tokenizer.save('gpt_tts_tokenizer.json')
+
+
+if __name__ == '__main__':
+    '''
+    build_text_file_from_priors([('Y:\\bigasr_dataset\\libritts\\train-all.txt', 'libritts'),
+                                 ('Y:\\bigasr_dataset\\libritts\\test-clean_list.txt', 'libritts'),
+                                 #('Y:\\bigasr_dataset\\voxpopuli\\audio\\transcribed_data\\en\\asr_en.tsv', 'voxpopuli'),
+                                 ('Y:\\bigasr_dataset\\voxpopuli\\audio\\transcribed_data\\en\\asr_train.tsv', 'voxpopuli'),
+                                 ('Y:\\clips\\books1-transcribed.tsv', 'tsv'),
+                                 ('Y:\\clips\\books2-transcribed.tsv', 'tsv'),
+                                 ('Y:\\clips\\podcasts-0-transcribed.tsv', 'tsv')], 'all_texts.txt')
+    '''
+    train()
\ No newline at end of file
diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py
index b804d0d8..54bf6ffe 100644
--- a/codes/models/gpt_voice/gpt_tts_hf.py
+++ b/codes/models/gpt_voice/gpt_tts_hf.py
@@ -39,16 +39,18 @@ class ConditioningEncoder(nn.Module):
 
 
 class GptTtsHf(nn.Module):
-    NUMBER_TEXT_TOKENS = 50257  # The number of BPE tokens produced by the HF GPT2Tokenizer
-    START_TEXT_TOKEN = 50256
+    NUMBER_TEXT_TOKENS = 10000  # The number of tokens produced by our bespoke BPE tokenizer.
+    START_TEXT_TOKEN = 9999
     STOP_TEXT_TOKEN = 0
     NUMBER_MEL_CODES = 8194
     START_MEL_TOKEN = 8192
     STOP_MEL_TOKEN = 8193
 
-    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=100, max_mel_tokens=250, max_conditioning_inputs=3,
+    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=80, max_mel_tokens=250, max_conditioning_inputs=3,
                  checkpointing=True, mel_length_compression=1024, max_conditioning_length=60):
         super().__init__()
+
+        self.max_mel_tokens = max_mel_tokens
         self.max_symbols_per_phrase = max_symbols_per_phrase
         self.model_dim = model_dim
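
Note on how the retrained vocabulary lines up with the new model constants: BpeTrainer is capped at vocab_size=9999 with [STOP] and [UNK] as specials, so [STOP] takes id 0 (STOP_TEXT_TOKEN) and ordinary text tokens occupy ids 1..9998, while id 9999 is reserved by the model as START_TEXT_TOKEN, giving NUMBER_TEXT_TOKENS = 10000. A minimal usage sketch with the tokenizers API (the vocab path and the sample sentence are illustrative, not taken from this patch):

from tokenizers import Tokenizer

# Load the bespoke BPE vocab written by voice_tokenizer_builder.train();
# the dataset normally reads this path from hparams['tokenizer_vocab'].
tokenizer = Tokenizer.from_file('../experiments/gpt_tts_tokenizer.json')

ids = tokenizer.encode('the quick brown fox').ids
# Ordinary text should never encode to id 0 ([STOP] / STOP_TEXT_TOKEN) or
# 9999 (START_TEXT_TOKEN, outside the trained vocab), which is what the
# asserts in TextWavLoader.get_text() guard against.
assert all(0 < i < 9999 for i in ids)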