Fix phoneme tokenizer
This commit is contained in:
parent
51f8c1bced
commit
eb64d18075
|
@ -10,6 +10,7 @@ import torch.nn.functional as F
|
|||
import torch.utils.data
|
||||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
from transformers import Wav2Vec2Processor
|
||||
|
||||
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
|
||||
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
|
||||
|
@ -72,20 +73,25 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
|
||||
else:
|
||||
self.tokenizer = CharacterTokenizer()
|
||||
self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
|
||||
self.ipa_phoneme_tokenizer.do_phonemize = False
|
||||
self.skipped_items = 0 # records how many items are skipped when accessing an index.
|
||||
|
||||
self.load_times = torch.zeros((256,))
|
||||
self.load_ind = 0
|
||||
|
||||
def get_wav_text_pair(self, audiopath_and_text):
|
||||
def get_wav_text_pair(self, audiopath_and_text, is_phonetic):
|
||||
# separate filename and text
|
||||
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
|
||||
text_seq = self.get_text(text)
|
||||
text_seq = self.get_text(text, is_phonetic)
|
||||
wav = load_audio(audiopath, self.sample_rate)
|
||||
return (text_seq, wav, text, audiopath_and_text[0])
|
||||
|
||||
def get_text(self, text):
|
||||
tokens = self.tokenizer.encode(text)
|
||||
def get_text(self, text, is_phonetic):
|
||||
if is_phonetic:
|
||||
tokens = self.ipa_phoneme_tokenizer.encode(text)
|
||||
else:
|
||||
tokens = self.tokenizer.encode(text)
|
||||
tokens = torch.IntTensor(tokens)
|
||||
if self.use_bpe_tokenizer:
|
||||
# Assert if any UNK,start tokens encountered.
|
||||
|
@ -161,7 +167,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
self.skipped_items += 1
|
||||
apt, type, is_phonetic = self.load_random_line()
|
||||
try:
|
||||
tseq, wav, text, path = self.get_wav_text_pair(apt)
|
||||
tseq, wav, text, path = self.get_wav_text_pair(apt, is_phonetic)
|
||||
if text is None or len(text.strip()) == 0:
|
||||
raise ValueError
|
||||
cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
|
||||
|
@ -291,6 +297,7 @@ if __name__ == '__main__':
|
|||
'conditioning_length': 102400,
|
||||
'use_bpe_tokenizer': True,
|
||||
'load_aligned_codes': False,
|
||||
'debug_loading_failures': True,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
|
||||
|
|
0
codes/scripts/audio/prep_music/mt3_transcribe.py
Normal file
0
codes/scripts/audio/prep_music/mt3_transcribe.py
Normal file
|
@ -327,7 +327,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_waveform_gen3.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user