forked from mrq/DL-Art-School
Allow audio sample rate interpolation for faster training
This commit is contained in:
parent
96e90e7047
commit
49e3b310ea
|
@ -24,6 +24,7 @@ class TextMelLoader(torch.utils.data.Dataset):
|
||||||
self.sampling_rate = hparams.sampling_rate
|
self.sampling_rate = hparams.sampling_rate
|
||||||
self.load_mel_from_disk = hparams.load_mel_from_disk
|
self.load_mel_from_disk = hparams.load_mel_from_disk
|
||||||
self.return_wavs = hparams.return_wavs
|
self.return_wavs = hparams.return_wavs
|
||||||
|
self.input_sample_rate = hparams.input_sample_rate
|
||||||
assert not (self.load_mel_from_disk and self.return_wavs)
|
assert not (self.load_mel_from_disk and self.return_wavs)
|
||||||
self.stft = layers.TacotronSTFT(
|
self.stft = layers.TacotronSTFT(
|
||||||
hparams.filter_length, hparams.hop_length, hparams.win_length,
|
hparams.filter_length, hparams.hop_length, hparams.win_length,
|
||||||
|
@ -43,12 +44,14 @@ class TextMelLoader(torch.utils.data.Dataset):
|
||||||
def get_mel(self, filename):
|
def get_mel(self, filename):
|
||||||
if not self.load_mel_from_disk:
|
if not self.load_mel_from_disk:
|
||||||
audio, sampling_rate = load_wav_to_torch(filename)
|
audio, sampling_rate = load_wav_to_torch(filename)
|
||||||
if sampling_rate != self.stft.sampling_rate:
|
if sampling_rate != self.input_sample_rate:
|
||||||
raise ValueError("{} {} SR doesn't match target {} SR".format(
|
raise ValueError(f"Input sampling rate does not match specified rate {self.input_sample_rate}")
|
||||||
sampling_rate, self.stft.sampling_rate))
|
|
||||||
audio_norm = audio / self.max_wav_value
|
audio_norm = audio / self.max_wav_value
|
||||||
audio_norm = audio_norm.unsqueeze(0)
|
audio_norm = audio_norm.unsqueeze(0)
|
||||||
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
||||||
|
if self.input_sample_rate != self.sampling_rate:
|
||||||
|
ratio = self.sampling_rate / self.input_sample_rate
|
||||||
|
audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0)
|
||||||
if self.return_wavs:
|
if self.return_wavs:
|
||||||
melspec = audio_norm
|
melspec = audio_norm
|
||||||
else:
|
else:
|
||||||
|
@ -133,6 +136,8 @@ if __name__ == '__main__':
|
||||||
'n_workers': 0,
|
'n_workers': 0,
|
||||||
'batch_size': 2,
|
'batch_size': 2,
|
||||||
'return_wavs': True,
|
'return_wavs': True,
|
||||||
|
'input_sample_rate': 22050,
|
||||||
|
'sampling_rate': 8000
|
||||||
}
|
}
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ def create_hparams(hparams_string=None, verbose=False):
|
||||||
# Audio Parameters #
|
# Audio Parameters #
|
||||||
################################
|
################################
|
||||||
max_wav_value=32768.0,
|
max_wav_value=32768.0,
|
||||||
|
input_sample_rate=22050, # When different from sampling_rate, dataset automatically interpolates to sampling_rate
|
||||||
sampling_rate=22050,
|
sampling_rate=22050,
|
||||||
filter_length=1024,
|
filter_length=1024,
|
||||||
hop_length=256,
|
hop_length=256,
|
||||||
|
|
|
@ -15,10 +15,10 @@ from utils.util import opt_get, checkpoint
|
||||||
|
|
||||||
|
|
||||||
class WavDecoder(nn.Module):
|
class WavDecoder(nn.Module):
|
||||||
def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1):
|
def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dec_channels = dec_channels
|
self.dec_channels = dec_channels
|
||||||
self.K = int(sample_rate * (K_ms/1000)) # 960 with the defaults
|
self.K = int(sample_rate * (K_ms/1000))
|
||||||
self.clarifier = UNetModel(image_size=self.K,
|
self.clarifier = UNetModel(image_size=self.K,
|
||||||
in_channels=1,
|
in_channels=1,
|
||||||
model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
|
model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
|
||||||
|
@ -189,7 +189,7 @@ class WaveTacotron2(nn.Module):
|
||||||
if self.mask_padding and output_lengths is not None:
|
if self.mask_padding and output_lengths is not None:
|
||||||
mask_fill = outputs[0].shape[-1]
|
mask_fill = outputs[0].shape[-1]
|
||||||
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
|
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
|
||||||
mask = mask.expand(mask.size(0), 2, mask.size(1))
|
mask = mask.unsqueeze(1).repeat(1,2,1)
|
||||||
|
|
||||||
outputs[0].data.masked_fill_(mask, 0.0)
|
outputs[0].data.masked_fill_(mask, 0.0)
|
||||||
outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.
|
outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user