forked from mrq/DL-Art-School
nv_tacotron_dataset - Allow training on mozilla cv
This commit is contained in:
parent
d0c74278bf
commit
2d3f0cc33c
|
@ -1,5 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import audio2numpy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
|
@ -12,6 +14,13 @@ from models.tacotron2.text import text_to_sequence
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
def load_mozilla_cv(filename):
    """Read a Mozilla Common Voice style TSV and return [audio_path, text] pairs.

    The first line of the file is treated as a header and skipped. For every
    remaining non-blank row, the second tab-separated column is taken as the
    clip file name (prefixed with ``clips/``) and the third column as the
    transcript.

    Args:
        filename: Path to the tab-separated metadata file (read as UTF-8).

    Returns:
        A list of ``[f'clips/{clip_name}', sentence]`` lists, in file order.
        An empty or header-only file yields an empty list.

    Raises:
        IndexError: If a non-blank data row has fewer than three columns.
    """
    filepaths_and_text = []
    with open(filename, encoding='utf-8') as f:
        # First line is the header; the default keeps an empty file from raising.
        next(f, None)
        for line in f:
            # Blank lines (e.g. a trailing newline at EOF) carry no record.
            if not line.strip():
                continue
            # Only drop the line terminator: stripping leading whitespace would
            # shift the column indices whenever the first column is empty.
            component = line.rstrip('\r\n').split('\t')
            filepaths_and_text.append([f'clips/{component[1]}', component[2]])
    return filepaths_and_text
|
||||||
|
|
||||||
|
|
||||||
class TextMelLoader(torch.utils.data.Dataset):
|
class TextMelLoader(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
1) loads audio,text pairs
|
1) loads audio,text pairs
|
||||||
|
@ -20,7 +29,15 @@ class TextMelLoader(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
self.path = os.path.dirname(hparams['path'])
|
self.path = os.path.dirname(hparams['path'])
|
||||||
self.audiopaths_and_text = load_filepaths_and_text(hparams['path'])
|
fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
|
||||||
|
fetcher_fn = None
|
||||||
|
if fetcher_mode == 'lj':
|
||||||
|
fetcher_fn = load_filepaths_and_text
|
||||||
|
elif fetcher_mode == 'mozilla_cv':
|
||||||
|
fetcher_fn = load_mozilla_cv
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
self.audiopaths_and_text = fetcher_fn(hparams['path'])
|
||||||
self.text_cleaners = hparams.text_cleaners
|
self.text_cleaners = hparams.text_cleaners
|
||||||
self.max_wav_value = hparams.max_wav_value
|
self.max_wav_value = hparams.max_wav_value
|
||||||
self.sampling_rate = hparams.sampling_rate
|
self.sampling_rate = hparams.sampling_rate
|
||||||
|
@ -45,11 +62,18 @@ class TextMelLoader(torch.utils.data.Dataset):
|
||||||
|
|
||||||
def get_mel(self, filename):
|
def get_mel(self, filename):
|
||||||
if not self.load_mel_from_disk:
|
if not self.load_mel_from_disk:
|
||||||
audio, sampling_rate = load_wav_to_torch(filename)
|
if filename.endswith('.wav'):
|
||||||
|
audio, sampling_rate = load_wav_to_torch(filename)
|
||||||
|
audio = audio / self.max_wav_value
|
||||||
|
else:
|
||||||
|
audio, sampling_rate = audio2numpy.audio_from_file(filename)
|
||||||
|
audio = torch.tensor(audio)
|
||||||
|
|
||||||
if sampling_rate != self.input_sample_rate:
|
if sampling_rate != self.input_sample_rate:
|
||||||
raise ValueError(f"Input sampling rate does not match specified rate {self.input_sample_rate}")
|
assert sampling_rate > self.input_sample_rate # Upsampling is not a great idea.
|
||||||
audio_norm = audio / self.max_wav_value
|
audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='area')
|
||||||
audio_norm = audio_norm.unsqueeze(0)
|
audio = (audio.squeeze().clip(-1,1)+1)/2
|
||||||
|
audio_norm = audio.unsqueeze(0)
|
||||||
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
||||||
if self.input_sample_rate != self.sampling_rate:
|
if self.input_sample_rate != self.sampling_rate:
|
||||||
ratio = self.sampling_rate / self.input_sample_rate
|
ratio = self.sampling_rate / self.input_sample_rate
|
||||||
|
@ -137,10 +161,11 @@ class TextMelCollate():
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
params = {
|
params = {
|
||||||
'mode': 'nv_tacotron',
|
'mode': 'nv_tacotron',
|
||||||
'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
|
'path': 'E:\\audio\\MozillaCommonVoice\\en\\test.tsv',
|
||||||
'phase': 'train',
|
'phase': 'train',
|
||||||
'n_workers': 1,
|
'n_workers': 0,
|
||||||
'batch_size': 32,
|
'batch_size': 32,
|
||||||
|
'fetcher_mode': 'mozilla_cv',
|
||||||
#'return_wavs': True,
|
#'return_wavs': True,
|
||||||
#'input_sample_rate': 22050,
|
#'input_sample_rate': 22050,
|
||||||
#'sampling_rate': 8000
|
#'sampling_rate': 8000
|
||||||
|
|
Loading…
Reference in New Issue
Block a user