Fix phoneme tokenizer

This commit is contained in:
James Betker 2022-05-13 17:56:26 -06:00
parent 51f8c1bced
commit eb64d18075
3 changed files with 13 additions and 6 deletions

View File

@ -10,6 +10,7 @@ import torch.nn.functional as F
import torch.utils.data
import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2Processor
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
@ -72,20 +73,25 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else:
self.tokenizer = CharacterTokenizer()
self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
self.ipa_phoneme_tokenizer.do_phonemize = False
self.skipped_items = 0 # records how many items are skipped when accessing an index.
self.load_times = torch.zeros((256,))
self.load_ind = 0
def get_wav_text_pair(self, audiopath_and_text):
def get_wav_text_pair(self, audiopath_and_text, is_phonetic):
# separate filename and text
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
text_seq = self.get_text(text)
text_seq = self.get_text(text, is_phonetic)
wav = load_audio(audiopath, self.sample_rate)
return (text_seq, wav, text, audiopath_and_text[0])
def get_text(self, text):
tokens = self.tokenizer.encode(text)
def get_text(self, text, is_phonetic):
if is_phonetic:
tokens = self.ipa_phoneme_tokenizer.encode(text)
else:
tokens = self.tokenizer.encode(text)
tokens = torch.IntTensor(tokens)
if self.use_bpe_tokenizer:
# Assert if any UNK,start tokens encountered.
@ -161,7 +167,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.skipped_items += 1
apt, type, is_phonetic = self.load_random_line()
try:
tseq, wav, text, path = self.get_wav_text_pair(apt)
tseq, wav, text, path = self.get_wav_text_pair(apt, is_phonetic)
if text is None or len(text.strip()) == 0:
raise ValueError
cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
@ -291,6 +297,7 @@ if __name__ == '__main__':
'conditioning_length': 102400,
'use_bpe_tokenizer': True,
'load_aligned_codes': False,
'debug_loading_failures': True,
}
from data import create_dataset, create_dataloader

View File

@ -327,7 +327,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_waveform_gen3.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)