Fix DS
This commit is contained in:
parent
736c2626ee
commit
746392f35c
|
@ -6,6 +6,7 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
from munch import munchify
|
||||||
from tokenizers import Tokenizer
|
from tokenizers import Tokenizer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
|
@ -49,7 +50,7 @@ def load_voxpopuli(filename):
|
||||||
|
|
||||||
class CharacterTokenizer:
|
class CharacterTokenizer:
|
||||||
def encode(self, txt):
|
def encode(self, txt):
|
||||||
return text_to_sequence(txt, ['english_cleaners'])
|
return munchify({'ids': text_to_sequence(txt, ['english_cleaners'])})
|
||||||
|
|
||||||
def decode(self, seq):
|
def decode(self, seq):
|
||||||
return sequence_to_text(seq)
|
return sequence_to_text(seq)
|
||||||
|
@ -95,7 +96,8 @@ class TextWavLoader(torch.utils.data.Dataset):
|
||||||
self.needs_collate = opt_get(hparams, ['needs_collate'], True)
|
self.needs_collate = opt_get(hparams, ['needs_collate'], True)
|
||||||
if not self.needs_collate:
|
if not self.needs_collate:
|
||||||
assert self.max_wav_len is not None and self.max_text_len is not None
|
assert self.max_wav_len is not None and self.max_text_len is not None
|
||||||
if opt_get(hparams, ['use_bpe_tokenizer'], True):
|
self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True)
|
||||||
|
if self.use_bpe_tokenizer:
|
||||||
self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
|
self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
|
||||||
else:
|
else:
|
||||||
self.tokenizer = CharacterTokenizer()
|
self.tokenizer = CharacterTokenizer()
|
||||||
|
@ -110,10 +112,11 @@ class TextWavLoader(torch.utils.data.Dataset):
|
||||||
def get_text(self, text):
|
def get_text(self, text):
|
||||||
tokens = self.tokenizer.encode(text.strip().lower()).ids
|
tokens = self.tokenizer.encode(text.strip().lower()).ids
|
||||||
tokens = torch.IntTensor(tokens)
|
tokens = torch.IntTensor(tokens)
|
||||||
# Assert if any UNK,start,stop tokens encountered.
|
if self.use_bpe_tokenizer:
|
||||||
assert not torch.any(tokens == 0)
|
# Assert if any UNK,start tokens encountered.
|
||||||
assert not torch.any(tokens == 1)
|
assert not torch.any(tokens == 1)
|
||||||
assert not torch.any(tokens == 9999)
|
# The stop token should always be sacred.
|
||||||
|
assert not torch.any(tokens == 0)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def load_conditioning_candidates(self, path):
|
def load_conditioning_candidates(self, path):
|
||||||
|
@ -245,6 +248,7 @@ if __name__ == '__main__':
|
||||||
'max_text_length': 200,
|
'max_text_length': 200,
|
||||||
'sample_rate': 22050,
|
'sample_rate': 22050,
|
||||||
'load_conditioning': False,
|
'load_conditioning': False,
|
||||||
|
'use_bpe_tokenizer': False,
|
||||||
}
|
}
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user