From eb64d18075832238629663ac70c5cbb09016f17c Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 13 May 2022 17:56:26 -0600
Subject: [PATCH] Fix phoneme tokenizer

---
 .../audio/fast_paired_dataset_with_phonemes.py  | 17 ++++++++++++-----
 .../scripts/audio/prep_music/mt3_transcribe.py  |  0
 codes/train.py                                  |  2 +-
 3 files changed, 13 insertions(+), 6 deletions(-)
 create mode 100644 codes/scripts/audio/prep_music/mt3_transcribe.py

diff --git a/codes/data/audio/fast_paired_dataset_with_phonemes.py b/codes/data/audio/fast_paired_dataset_with_phonemes.py
index 95d1004e..54b208b9 100644
--- a/codes/data/audio/fast_paired_dataset_with_phonemes.py
+++ b/codes/data/audio/fast_paired_dataset_with_phonemes.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import torch.utils.data
 import torchaudio
 from tqdm import tqdm
+from transformers import Wav2Vec2Processor
 
 from data.audio.paired_voice_audio_dataset import CharacterTokenizer
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
@@ -72,20 +73,25 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
         else:
             self.tokenizer = CharacterTokenizer()
+        self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
+        self.ipa_phoneme_tokenizer.do_phonemize = False
         self.skipped_items = 0  # records how many items are skipped when accessing an index.
 
         self.load_times = torch.zeros((256,))
         self.load_ind = 0
 
-    def get_wav_text_pair(self, audiopath_and_text):
+    def get_wav_text_pair(self, audiopath_and_text, is_phonetic):
         # separate filename and text
         audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
-        text_seq = self.get_text(text)
+        text_seq = self.get_text(text, is_phonetic)
         wav = load_audio(audiopath, self.sample_rate)
         return (text_seq, wav, text, audiopath_and_text[0])
 
-    def get_text(self, text):
-        tokens = self.tokenizer.encode(text)
+    def get_text(self, text, is_phonetic):
+        if is_phonetic:
+            tokens = self.ipa_phoneme_tokenizer.encode(text)
+        else:
+            tokens = self.tokenizer.encode(text)
         tokens = torch.IntTensor(tokens)
         if self.use_bpe_tokenizer:
             # Assert if any UNK,start tokens encountered.
@@ -161,7 +167,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             self.skipped_items += 1
             apt, type, is_phonetic = self.load_random_line()
             try:
-                tseq, wav, text, path = self.get_wav_text_pair(apt)
+                tseq, wav, text, path = self.get_wav_text_pair(apt, is_phonetic)
                 if text is None or len(text.strip()) == 0:
                     raise ValueError
                 cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
@@ -291,6 +297,7 @@ if __name__ == '__main__':
         'conditioning_length': 102400,
         'use_bpe_tokenizer': True,
         'load_aligned_codes': False,
+        'debug_loading_failures': True,
     }
     from data import create_dataset, create_dataloader
 
diff --git a/codes/scripts/audio/prep_music/mt3_transcribe.py b/codes/scripts/audio/prep_music/mt3_transcribe.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/train.py b/codes/train.py
index e1377bb4..14039ec0 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_waveform_gen3.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
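Note on the new phoneme path (a minimal usage sketch, not part of the patch itself):
the Wav2Vec2 processor for "facebook/wav2vec2-lv-60-espeak-cv-ft" bundles an IPA
phoneme tokenizer, and setting do_phonemize = False makes it treat its input as
already-phonemized text, mapping space-separated IPA symbols directly to vocabulary
ids instead of invoking espeak. The snippet below assumes `transformers` is
installed and the checkpoint can be downloaded; the sample IPA string and variable
names are illustrative only.

    # Sketch of the tokenizer wiring added in this patch, used standalone.
    from transformers import Wav2Vec2Processor

    ipa_tokenizer = Wav2Vec2Processor.from_pretrained(
        "facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
    ipa_tokenizer.do_phonemize = False  # inputs are pre-phonemized IPA

    ipa_text = "h ə l oʊ"  # illustrative; symbols outside the vocab map to the UNK id
    token_ids = ipa_tokenizer.encode(ipa_text)
    print(token_ids)  # a plain list of ints, one id per phoneme token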