forked from mrq/DL-Art-School

Fix phoneme tokenizer

parent 51f8c1bced
commit eb64d18075
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import torch.utils.data
 import torchaudio
 from tqdm import tqdm
+from transformers import Wav2Vec2Processor
 
 from data.audio.paired_voice_audio_dataset import CharacterTokenizer
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
@@ -72,20 +73,25 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
         else:
             self.tokenizer = CharacterTokenizer()
+        self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
+        self.ipa_phoneme_tokenizer.do_phonemize = False
         self.skipped_items = 0  # records how many items are skipped when accessing an index.
 
         self.load_times = torch.zeros((256,))
         self.load_ind = 0
 
-    def get_wav_text_pair(self, audiopath_and_text):
+    def get_wav_text_pair(self, audiopath_and_text, is_phonetic):
         # separate filename and text
         audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
-        text_seq = self.get_text(text)
+        text_seq = self.get_text(text, is_phonetic)
         wav = load_audio(audiopath, self.sample_rate)
         return (text_seq, wav, text, audiopath_and_text[0])
 
-    def get_text(self, text):
-        tokens = self.tokenizer.encode(text)
+    def get_text(self, text, is_phonetic):
+        if is_phonetic:
+            tokens = self.ipa_phoneme_tokenizer.encode(text)
+        else:
+            tokens = self.tokenizer.encode(text)
         tokens = torch.IntTensor(tokens)
         if self.use_bpe_tokenizer:
             # Assert if any UNK,start tokens encountered.
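For context, the added ipa_phoneme_tokenizer is the phoneme CTC tokenizer bundled with the facebook/wav2vec2-lv-60-espeak-cv-ft checkpoint, and setting do_phonemize = False tells it to skip its own espeak phonemization and simply map an already phonemized, space-separated IPA string to vocabulary ids. A minimal sketch of that behaviour, outside this commit and assuming transformers is installed and the checkpoint can be downloaded (the input string is a made-up example):

from transformers import Wav2Vec2Processor

tok = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer

# With do_phonemize=False the tokenizer does not call espeak itself; it expects
# text that is already a space-separated IPA phoneme sequence and looks each
# phoneme up in its vocabulary (unrecognized symbols fall back to <unk>).
tok.do_phonemize = False
ids = tok.encode("h ə l oʊ")  # hypothetical pre-phonemized input
print(ids)

# With do_phonemize=True (the default) the tokenizer would phonemize raw text
# first, which additionally requires the phonemizer package and an espeak backend.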
@@ -161,7 +167,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         self.skipped_items += 1
         apt, type, is_phonetic = self.load_random_line()
         try:
-            tseq, wav, text, path = self.get_wav_text_pair(apt)
+            tseq, wav, text, path = self.get_wav_text_pair(apt, is_phonetic)
             if text is None or len(text.strip()) == 0:
                 raise ValueError
             cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
@@ -291,6 +297,7 @@ if __name__ == '__main__':
         'conditioning_length': 102400,
         'use_bpe_tokenizer': True,
         'load_aligned_codes': False,
+        'debug_loading_failures': True,
     }
     from data import create_dataset, create_dataloader
 
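With debug_loading_failures enabled, a quick way to exercise both tokenizer paths is to build the dataset from the params dict above and pull one batch. A hypothetical smoke test, assuming create_dataset and create_dataloader accept the options dict the way this __main__ block uses them, and that the collated batch exposes a padded_text tensor:

from data import create_dataset, create_dataloader

ds = create_dataset(params)         # params is the dict shown in the hunk above
dl = create_dataloader(ds, params)  # assumed signature: (dataset, options_dict)
batch = next(iter(dl))
print(batch['padded_text'].shape)   # 'padded_text' is an assumed key of the collated batch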
codes/scripts/audio/prep_music/mt3_transcribe.py (new file, 0 lines)
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_waveform_gen3.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
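With the changed default, launching the trainer without -opt now picks up the music waveform config, while passing -opt explicitly still overrides it. A hypothetical invocation (the script name and working directory are assumptions, only the -opt flag comes from the diff):

python train.py -opt ../options/train_music_waveform_gen3.yml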