This commit is contained in:
James Betker 2021-12-25 15:28:59 -07:00
parent 736c2626ee
commit 746392f35c

View File

@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.data import torch.utils.data
import torchaudio import torchaudio
from munch import munchify
from tokenizers import Tokenizer from tokenizers import Tokenizer
from tqdm import tqdm from tqdm import tqdm
from transformers import GPT2TokenizerFast from transformers import GPT2TokenizerFast
@ -49,7 +50,7 @@ def load_voxpopuli(filename):
class CharacterTokenizer: class CharacterTokenizer:
def encode(self, txt): def encode(self, txt):
return text_to_sequence(txt, ['english_cleaners']) return munchify({'ids': text_to_sequence(txt, ['english_cleaners'])})
def decode(self, seq): def decode(self, seq):
return sequence_to_text(seq) return sequence_to_text(seq)
@ -95,7 +96,8 @@ class TextWavLoader(torch.utils.data.Dataset):
self.needs_collate = opt_get(hparams, ['needs_collate'], True) self.needs_collate = opt_get(hparams, ['needs_collate'], True)
if not self.needs_collate: if not self.needs_collate:
assert self.max_wav_len is not None and self.max_text_len is not None assert self.max_wav_len is not None and self.max_text_len is not None
if opt_get(hparams, ['use_bpe_tokenizer'], True): self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True)
if self.use_bpe_tokenizer:
self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json')) self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else: else:
self.tokenizer = CharacterTokenizer() self.tokenizer = CharacterTokenizer()
@ -110,10 +112,11 @@ class TextWavLoader(torch.utils.data.Dataset):
def get_text(self, text): def get_text(self, text):
tokens = self.tokenizer.encode(text.strip().lower()).ids tokens = self.tokenizer.encode(text.strip().lower()).ids
tokens = torch.IntTensor(tokens) tokens = torch.IntTensor(tokens)
# Assert if any UNK,start,stop tokens encountered. if self.use_bpe_tokenizer:
assert not torch.any(tokens == 0) # Assert if any UNK,start tokens encountered.
assert not torch.any(tokens == 1) assert not torch.any(tokens == 1)
assert not torch.any(tokens == 9999) # The stop token should always be sacred.
assert not torch.any(tokens == 0)
return tokens return tokens
def load_conditioning_candidates(self, path): def load_conditioning_candidates(self, path):
@ -245,6 +248,7 @@ if __name__ == '__main__':
'max_text_length': 200, 'max_text_length': 200,
'sample_rate': 22050, 'sample_rate': 22050,
'load_conditioning': False, 'load_conditioning': False,
'use_bpe_tokenizer': False,
} }
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader