Allow audio sample rate interpolation for faster training
This commit is contained in:
parent
96e90e7047
commit
49e3b310ea
|
@ -24,6 +24,7 @@ class TextMelLoader(torch.utils.data.Dataset):
|
|||
self.sampling_rate = hparams.sampling_rate
|
||||
self.load_mel_from_disk = hparams.load_mel_from_disk
|
||||
self.return_wavs = hparams.return_wavs
|
||||
self.input_sample_rate = hparams.input_sample_rate
|
||||
assert not (self.load_mel_from_disk and self.return_wavs)
|
||||
self.stft = layers.TacotronSTFT(
|
||||
hparams.filter_length, hparams.hop_length, hparams.win_length,
|
||||
|
@ -43,12 +44,14 @@ class TextMelLoader(torch.utils.data.Dataset):
|
|||
def get_mel(self, filename):
|
||||
if not self.load_mel_from_disk:
|
||||
audio, sampling_rate = load_wav_to_torch(filename)
|
||||
if sampling_rate != self.stft.sampling_rate:
|
||||
raise ValueError("{} {} SR doesn't match target {} SR".format(
|
||||
sampling_rate, self.stft.sampling_rate))
|
||||
if sampling_rate != self.input_sample_rate:
|
||||
raise ValueError(f"Input sampling rate does not match specified rate {self.input_sample_rate}")
|
||||
audio_norm = audio / self.max_wav_value
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
||||
if self.input_sample_rate != self.sampling_rate:
|
||||
ratio = self.sampling_rate / self.input_sample_rate
|
||||
audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0)
|
||||
if self.return_wavs:
|
||||
melspec = audio_norm
|
||||
else:
|
||||
|
@ -133,6 +136,8 @@ if __name__ == '__main__':
|
|||
'n_workers': 0,
|
||||
'batch_size': 2,
|
||||
'return_wavs': True,
|
||||
'input_sample_rate': 22050,
|
||||
'sampling_rate': 8000
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ def create_hparams(hparams_string=None, verbose=False):
|
|||
# Audio Parameters #
|
||||
################################
|
||||
max_wav_value=32768.0,
|
||||
input_sample_rate=22050, # When different from sampling_rate, dataset automatically interpolates to sampling_rate
|
||||
sampling_rate=22050,
|
||||
filter_length=1024,
|
||||
hop_length=256,
|
||||
|
|
|
@ -15,10 +15,10 @@ from utils.util import opt_get, checkpoint
|
|||
|
||||
|
||||
class WavDecoder(nn.Module):
|
||||
def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1):
|
||||
def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
|
||||
super().__init__()
|
||||
self.dec_channels = dec_channels
|
||||
self.K = int(sample_rate * (K_ms/1000)) # 960 with the defaults
|
||||
self.K = int(sample_rate * (K_ms/1000))
|
||||
self.clarifier = UNetModel(image_size=self.K,
|
||||
in_channels=1,
|
||||
model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
|
||||
|
@ -189,7 +189,7 @@ class WaveTacotron2(nn.Module):
|
|||
if self.mask_padding and output_lengths is not None:
|
||||
mask_fill = outputs[0].shape[-1]
|
||||
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
|
||||
mask = mask.expand(mask.size(0), 2, mask.size(1))
|
||||
mask = mask.unsqueeze(1).repeat(1,2,1)
|
||||
|
||||
outputs[0].data.masked_fill_(mask, 0.0)
|
||||
outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.
|
||||
|
|
Loading…
Reference in New Issue
Block a user