This commit is contained in:
James Betker 2021-12-25 15:28:59 -07:00
parent 736c2626ee
commit 746392f35c

View File

@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from munch import munchify
from tokenizers import Tokenizer
from tqdm import tqdm
from transformers import GPT2TokenizerFast
@ -49,7 +50,7 @@ def load_voxpopuli(filename):
class CharacterTokenizer:
def encode(self, txt):
return text_to_sequence(txt, ['english_cleaners'])
return munchify({'ids': text_to_sequence(txt, ['english_cleaners'])})
def decode(self, seq):
return sequence_to_text(seq)
@ -95,7 +96,8 @@ class TextWavLoader(torch.utils.data.Dataset):
self.needs_collate = opt_get(hparams, ['needs_collate'], True)
if not self.needs_collate:
assert self.max_wav_len is not None and self.max_text_len is not None
if opt_get(hparams, ['use_bpe_tokenizer'], True):
self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True)
if self.use_bpe_tokenizer:
self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else:
self.tokenizer = CharacterTokenizer()
@ -110,10 +112,11 @@ class TextWavLoader(torch.utils.data.Dataset):
def get_text(self, text):
tokens = self.tokenizer.encode(text.strip().lower()).ids
tokens = torch.IntTensor(tokens)
# Assert if any UNK,start,stop tokens encountered.
assert not torch.any(tokens == 0)
if self.use_bpe_tokenizer:
# Assert if any UNK,start tokens encountered.
assert not torch.any(tokens == 1)
assert not torch.any(tokens == 9999)
# The stop token should always be sacred.
assert not torch.any(tokens == 0)
return tokens
def load_conditioning_candidates(self, path):
@ -245,6 +248,7 @@ if __name__ == '__main__':
'max_text_length': 200,
'sample_rate': 22050,
'load_conditioning': False,
'use_bpe_tokenizer': False,
}
from data import create_dataset, create_dataloader