forked from mrq/DL-Art-School
More fixes
This commit is contained in:
parent 191e0130ee
commit b9de8a8eda
codes/data/audio/gpt_tts_tokenizer.json | 1 + (new file)
File diff suppressed because one or more lines are too long
@@ -85,7 +85,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.needs_collate = opt_get(hparams, ['needs_collate'], True)
         if not self.needs_collate:
             assert self.max_wav_len is not None and self.max_text_len is not None
-        self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/gpt_tts_tokenizer.json'))
+        self.tokenizer = Tokenizer.from_file(opt_get(hparams, ['tokenizer_vocab'], '../experiments/custom_lowercase_gptvoice_tokenizer.json'))

     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
@@ -95,7 +95,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         return (text_seq, wav, text, audiopath_and_text[0])

     def get_text(self, text):
-        tokens = self.tokenizer.encode(text.lower()).ids
+        tokens = self.tokenizer.encode(text.strip().lower()).ids
         tokens = torch.IntTensor(tokens)
         # Assert if any UNK,start,stop tokens encountered.
         assert not torch.any(tokens == 0)
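For context on the two hunks above: the tokenizer is a HuggingFace tokenizers object loaded from a serialized JSON vocab, and get_text() relies on encode(...).ids never containing id 0, which this codebase reserves for UNK/start/stop tokens. A minimal standalone sketch of that flow, assuming the new default vocab path from the diff and the id-0 convention (both inferred from the diff, not checked against the full source file):

    import torch
    from tokenizers import Tokenizer

    # Assumed path: the new default introduced by this commit.
    tokenizer = Tokenizer.from_file('../experiments/custom_lowercase_gptvoice_tokenizer.json')

    def get_text(text: str) -> torch.IntTensor:
        # .strip() is the fix in this commit: stray leading/trailing whitespace
        # would otherwise tokenize differently from the cleaned training text.
        tokens = tokenizer.encode(text.strip().lower()).ids
        tokens = torch.IntTensor(tokens)
        # Fail loudly if any UNK/start/stop token (id 0) was emitted.
        assert not torch.any(tokens == 0)
        return tokens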
@@ -35,17 +35,19 @@ def train():
     bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train']
     wkd = datasets.load_dataset('wikipedia', '20200501.en', cache_dir='Z:\\huggingface_datasets\\cache')['train']

-    allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\']+$')
-    def preprocess_word(word):
-        word = word.lower()
+    allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—ʼ]+$')
+    def preprocess_word(word, report=False):
+        word = word.strip().lower()
         if not bool(allowed_characters_re.match(word)):
+            if report and word:
+                print(f"REPORTING: '{word}'")
             return ''
         return word

     def batch_iterator(batch_size=1000):
         print("Processing ASR texts.")
         for i in range(0, len(ttsd), batch_size):
-            yield [preprocess_word(t) for t in ttsd[i:i+batch_size]]
+            yield [preprocess_word(t, True) for t in ttsd[i:i+batch_size]]

         print("Processing bookcorpus.")
         for i in range(0, len(bcd), batch_size):
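The hunk above shows only the data-loading half of train(); batch_iterator() is the standard feed for the tokenizers library's train_from_iterator API. A sketch of how the pieces plausibly connect end to end, with a toy corpus standing in for ttsd/bcd/wkd; the BPE model, trainer settings, and save path below are illustrative assumptions, not taken from this diff:

    import re
    from tokenizers import Tokenizer, models, pre_tokenizers, trainers

    # Character whitelist from this commit; non-matching words are dropped
    # (and optionally reported, as the new report flag does).
    allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—ʼ]+$')

    def preprocess_word(word, report=False):
        word = word.strip().lower()
        if not bool(allowed_characters_re.match(word)):
            if report and word:
                print(f"REPORTING: '{word}'")
            return ''
        return word

    corpus = ["Hello, world!", "  ünïcode fails the whitelist and is dropped  "]  # toy stand-in

    def batch_iterator(batch_size=1000):
        for i in range(0, len(corpus), batch_size):
            yield [preprocess_word(t, True) for t in corpus[i:i + batch_size]]

    # Hypothetical training setup: vocab_size, special tokens, and the output
    # filename are placeholders; only the iterator pattern mirrors the diff.
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.BpeTrainer(vocab_size=512, special_tokens=['[UNK]', '[START]', '[STOP]'])
    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
    tokenizer.save('gpt_tts_tokenizer.json')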