Allow audio sample rate interpolation for faster training

This commit is contained in:
James Betker 2021-07-26 17:44:06 -06:00
parent 96e90e7047
commit 49e3b310ea
3 changed files with 12 additions and 6 deletions

View File

@ -24,6 +24,7 @@ class TextMelLoader(torch.utils.data.Dataset):
self.sampling_rate = hparams.sampling_rate self.sampling_rate = hparams.sampling_rate
self.load_mel_from_disk = hparams.load_mel_from_disk self.load_mel_from_disk = hparams.load_mel_from_disk
self.return_wavs = hparams.return_wavs self.return_wavs = hparams.return_wavs
self.input_sample_rate = hparams.input_sample_rate
assert not (self.load_mel_from_disk and self.return_wavs) assert not (self.load_mel_from_disk and self.return_wavs)
self.stft = layers.TacotronSTFT( self.stft = layers.TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length, hparams.filter_length, hparams.hop_length, hparams.win_length,
@ -43,12 +44,14 @@ class TextMelLoader(torch.utils.data.Dataset):
def get_mel(self, filename): def get_mel(self, filename):
if not self.load_mel_from_disk: if not self.load_mel_from_disk:
audio, sampling_rate = load_wav_to_torch(filename) audio, sampling_rate = load_wav_to_torch(filename)
if sampling_rate != self.stft.sampling_rate: if sampling_rate != self.input_sample_rate:
raise ValueError("{} {} SR doesn't match target {} SR".format( raise ValueError(f"Input sampling rate does not match specified rate {self.input_sample_rate}")
sampling_rate, self.stft.sampling_rate))
audio_norm = audio / self.max_wav_value audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
if self.input_sample_rate != self.sampling_rate:
ratio = self.sampling_rate / self.input_sample_rate
audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0)
if self.return_wavs: if self.return_wavs:
melspec = audio_norm melspec = audio_norm
else: else:
@ -133,6 +136,8 @@ if __name__ == '__main__':
'n_workers': 0, 'n_workers': 0,
'batch_size': 2, 'batch_size': 2,
'return_wavs': True, 'return_wavs': True,
'input_sample_rate': 22050,
'sampling_rate': 8000
} }
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader

View File

@ -33,6 +33,7 @@ def create_hparams(hparams_string=None, verbose=False):
# Audio Parameters # # Audio Parameters #
################################ ################################
max_wav_value=32768.0, max_wav_value=32768.0,
input_sample_rate=22050, # When different from sampling_rate, dataset automatically interpolates to sampling_rate
sampling_rate=22050, sampling_rate=22050,
filter_length=1024, filter_length=1024,
hop_length=256, hop_length=256,

View File

@ -15,10 +15,10 @@ from utils.util import opt_get, checkpoint
class WavDecoder(nn.Module): class WavDecoder(nn.Module):
def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1): def __init__(self, dec_channels, K_ms=40, sample_rate=8000, dropout_probability=.1):
super().__init__() super().__init__()
self.dec_channels = dec_channels self.dec_channels = dec_channels
self.K = int(sample_rate * (K_ms/1000)) # 960 with the defaults self.K = int(sample_rate * (K_ms/1000))
self.clarifier = UNetModel(image_size=self.K, self.clarifier = UNetModel(image_size=self.K,
in_channels=1, in_channels=1,
model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model. model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
@ -189,7 +189,7 @@ class WaveTacotron2(nn.Module):
if self.mask_padding and output_lengths is not None: if self.mask_padding and output_lengths is not None:
mask_fill = outputs[0].shape[-1] mask_fill = outputs[0].shape[-1]
mask = ~get_mask_from_lengths(output_lengths, mask_fill) mask = ~get_mask_from_lengths(output_lengths, mask_fill)
mask = mask.expand(mask.size(0), 2, mask.size(1)) mask = mask.unsqueeze(1).repeat(1,2,1)
outputs[0].data.masked_fill_(mask, 0.0) outputs[0].data.masked_fill_(mask, 0.0)
outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension. outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.