diff --git a/codes/data/audio/paired_voice_audio_dataset.py b/codes/data/audio/paired_voice_audio_dataset.py index fdc63a3a..3201c967 100644 --- a/codes/data/audio/paired_voice_audio_dataset.py +++ b/codes/data/audio/paired_voice_audio_dataset.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F import torch.utils.data import torchaudio +from munch import munchify from tokenizers import Tokenizer from tqdm import tqdm from transformers import GPT2TokenizerFast @@ -49,7 +50,7 @@ def load_voxpopuli(filename): class CharacterTokenizer: def encode(self, txt): - return text_to_sequence(txt, ['english_cleaners']) + return munchify({'ids': text_to_sequence(txt, ['english_cleaners'])}) def decode(self, seq): return sequence_to_text(seq) @@ -95,7 +96,8 @@ class TextWavLoader(torch.utils.data.Dataset): self.needs_collate = opt_get(hparams, ['needs_collate'], True) if not self.needs_collate: assert self.max_wav_len is not None and self.max_text_len is not None - if opt_get(hparams, ['use_bpe_tokenizer'], True): + self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True) + if self.use_bpe_tokenizer: self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json')) else: self.tokenizer = CharacterTokenizer() @@ -110,10 +112,11 @@ class TextWavLoader(torch.utils.data.Dataset): def get_text(self, text): tokens = self.tokenizer.encode(text.strip().lower()).ids tokens = torch.IntTensor(tokens) - # Assert if any UNK,start,stop tokens encountered. + if self.use_bpe_tokenizer: + # Assert if any UNK,start tokens encountered. + assert not torch.any(tokens == 1) + # The stop token should always be sacred. assert not torch.any(tokens == 0) - assert not torch.any(tokens == 1) - assert not torch.any(tokens == 9999) return tokens def load_conditioning_candidates(self, path): @@ -245,6 +248,7 @@ if __name__ == '__main__': 'max_text_length': 200, 'sample_rate': 22050, 'load_conditioning': False, + 'use_bpe_tokenizer': False, } from data import create_dataset, create_dataloader