max/min mel invalid fix

This commit is contained in:
James Betker 2021-08-13 09:36:31 -06:00
parent 4b2946e581
commit fff1a59e08

View File

@ -81,6 +81,9 @@ class TextMelLoader(torch.utils.data.Dataset):
print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.')
audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='area', recompute_scale_factor=False)
audio = (audio.squeeze().clip(-1,1)+1)/2
if (audio.min() < -1).any() or (audio.max() > 1).any():
print(f"Error with audio ranging for {filename}; min={audio.min()} max={audio.max()}")
return None
audio_norm = audio.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
if self.input_sample_rate != self.sampling_rate:
@ -105,15 +108,16 @@ class TextMelLoader(torch.utils.data.Dataset):
def __getitem__(self, index):
t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index])
orig_output = m.shape[-1]
orig_text_len = t.shape[0]
mel_oversize = self.max_mel_len is not None and m.shape[-1] > self.max_mel_len
text_oversize = self.max_text_len is not None and t.shape[0] > self.max_text_len
if mel_oversize or text_oversize:
if m is None or \
(self.max_mel_len is not None and m.shape[-1] > self.max_mel_len) or \
(self.max_text_len is not None and t.shape[0] > self.max_text_len):
if m is not None:
print(f"Exception {index} mel_len:{m.shape[-1]} text_len:{t.shape[0]} fname: {p}")
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
rv = random.randint(0,len(self)-1)
return self[rv]
orig_output = m.shape[-1]
orig_text_len = t.shape[0]
if not self.needs_collate:
if m.shape[-1] != self.max_mel_len:
m = F.pad(m, (0, self.max_mel_len - m.shape[-1]))
@ -191,9 +195,9 @@ class TextMelCollate():
if __name__ == '__main__':
params = {
'mode': 'nv_tacotron',
'path': 'E:\\audio\\MozillaCommonVoice\\en\\test.tsv',
'path': 'E:\\audio\\MozillaCommonVoice\\en\\train.tsv',
'phase': 'train',
'n_workers': 0,
'n_workers': 12,
'batch_size': 32,
'fetcher_mode': 'mozilla_cv',
'needs_collate': False,
@ -209,6 +213,7 @@ if __name__ == '__main__':
dl = create_dataloader(ds, params, collate_fn=c)
i = 0
m = None
for k in range(1000):
for i, b in tqdm(enumerate(dl)):
continue
pm = b['padded_mel']