nv_tacotron_dataset - Allow training on mozilla cv
parent d0c74278bf
commit 2d3f0cc33c
@@ -1,5 +1,7 @@
 import os
 import random
+
+import audio2numpy
 import numpy as np
 import torch
 import torch.utils.data
@@ -12,6 +14,13 @@ from models.tacotron2.text import text_to_sequence
 from utils.util import opt_get
 
 
+def load_mozilla_cv(filename):
+    with open(filename, encoding='utf-8') as f:
+        components = [line.strip().split('\t') for line in f][1:]  # First line is the header
+        filepaths_and_text = [[f'clips/{component[1]}', component[2]] for component in components]
+    return filepaths_and_text
+
+
 class TextMelLoader(torch.utils.data.Dataset):
     """
     1) loads audio,text pairs
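For context on the new fetcher (not something this commit documents): Common Voice ships tab-separated .tsv metadata whose first row is a header, and load_mozilla_cv assumes the clip filename sits in the second column and the transcript in the third. The sample rows below use the standard Common Voice column names (client_id, path, sentence) purely as an illustration; a minimal sketch of what the function produces:

# Hypothetical Common Voice TSV excerpt; column order assumed: client_id, path, sentence, ...
sample_tsv = (
    "client_id\tpath\tsentence\tup_votes\tdown_votes\n"
    "abc123\tcommon_voice_en_0001.mp3\tHello world.\t2\t0\n"
    "def456\tcommon_voice_en_0002.mp3\tSecond clip.\t1\t0\n"
)

# Mirrors what load_mozilla_cv does with a real file handle.
components = [line.strip().split('\t') for line in sample_tsv.splitlines()][1:]
filepaths_and_text = [[f'clips/{c[1]}', c[2]] for c in components]
# -> [['clips/common_voice_en_0001.mp3', 'Hello world.'],
#     ['clips/common_voice_en_0002.mp3', 'Second clip.']]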
@@ -20,7 +29,15 @@ class TextMelLoader(torch.utils.data.Dataset):
     """
     def __init__(self, hparams):
         self.path = os.path.dirname(hparams['path'])
-        self.audiopaths_and_text = load_filepaths_and_text(hparams['path'])
+        fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
+        fetcher_fn = None
+        if fetcher_mode == 'lj':
+            fetcher_fn = load_filepaths_and_text
+        elif fetcher_mode == 'mozilla_cv':
+            fetcher_fn = load_mozilla_cv
+        else:
+            raise NotImplementedError()
+        self.audiopaths_and_text = fetcher_fn(hparams['path'])
         self.text_cleaners = hparams.text_cleaners
         self.max_wav_value = hparams.max_wav_value
         self.sampling_rate = hparams.sampling_rate
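The default of 'lj' means configs that never set fetcher_mode keep using load_filepaths_and_text, so existing LJSpeech setups are unaffected. opt_get itself is not part of this diff; the stand-in below is an assumed, simplified version that only illustrates the nested-lookup-with-default behavior relied on here, not the repo's actual utils.util.opt_get:

def opt_get(opt, keys, default=None):
    # Walk nested dict keys; fall back to `default` as soon as one is missing.
    val = opt
    for k in keys:
        if not isinstance(val, dict) or k not in val:
            return default
        val = val[k]
    return val

assert opt_get({'fetcher_mode': 'mozilla_cv'}, ['fetcher_mode'], 'lj') == 'mozilla_cv'
assert opt_get({'path': 'train.txt'}, ['fetcher_mode'], 'lj') == 'lj'  # old configs default to the LJ fetcher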
@@ -45,11 +62,18 @@ class TextMelLoader(torch.utils.data.Dataset):
 
     def get_mel(self, filename):
         if not self.load_mel_from_disk:
-            audio, sampling_rate = load_wav_to_torch(filename)
+            if filename.endswith('.wav'):
+                audio, sampling_rate = load_wav_to_torch(filename)
+                audio = audio / self.max_wav_value
+            else:
+                audio, sampling_rate = audio2numpy.audio_from_file(filename)
+                audio = torch.tensor(audio)
+
             if sampling_rate != self.input_sample_rate:
-                raise ValueError(f"Input sampling rate does not match specified rate {self.input_sample_rate}")
-            audio_norm = audio / self.max_wav_value
-            audio_norm = audio_norm.unsqueeze(0)
+                assert sampling_rate > self.input_sample_rate  # Upsampling is not a great idea.
+                audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='area')
+                audio = (audio.squeeze().clip(-1,1)+1)/2
+            audio_norm = audio.unsqueeze(0)
             audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
             if self.input_sample_rate != self.sampling_rate:
                 ratio = self.sampling_rate / self.input_sample_rate
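A rough, standalone sketch of the new resampling path in get_mel, under assumed rates (a 48 kHz source, as is typical for Common Voice mp3s, down to a 22050 Hz input_sample_rate; neither number comes from this diff). torch.nn.functional.interpolate in 'area' mode expects a [batch, channel, samples] tensor, hence the double unsqueeze before and the squeeze after:

import torch
import torch.nn.functional as F

source_rate, target_rate = 48000, 22050      # example rates, not taken from the diff
audio = torch.rand(source_rate) * 2 - 1      # 1 second of fake audio in [-1, 1]

resampled = F.interpolate(audio.unsqueeze(0).unsqueeze(1),        # [N] -> [1, 1, N]
                          scale_factor=target_rate / source_rate,
                          mode='area').squeeze()                  # back to 1-D, ~22050 samples
rescaled = (resampled.clip(-1, 1) + 1) / 2   # same [0, 1] rescale the diff applies after resampling
print(resampled.shape, rescaled.min().item() >= 0)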
@@ -137,10 +161,11 @@ class TextMelCollate():
 if __name__ == '__main__':
     params = {
         'mode': 'nv_tacotron',
-        'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
+        'path': 'E:\\audio\\MozillaCommonVoice\\en\\test.tsv',
         'phase': 'train',
-        'n_workers': 1,
+        'n_workers': 0,
         'batch_size': 32,
+        'fetcher_mode': 'mozilla_cv',
         #'return_wavs': True,
         #'input_sample_rate': 22050,
         #'sampling_rate': 8000