max/min mel invalid fix

This commit is contained in:
James Betker 2021-08-13 09:36:31 -06:00
parent 4b2946e581
commit fff1a59e08

View File

@ -81,6 +81,9 @@ class TextMelLoader(torch.utils.data.Dataset):
print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.') print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.')
audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='area', recompute_scale_factor=False) audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='area', recompute_scale_factor=False)
audio = (audio.squeeze().clip(-1,1)+1)/2 audio = (audio.squeeze().clip(-1,1)+1)/2
if (audio.min() < -1).any() or (audio.max() > 1).any():
print(f"Error with audio ranging for {filename}; min={audio.min()} max={audio.max()}")
return None
audio_norm = audio.unsqueeze(0) audio_norm = audio.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
if self.input_sample_rate != self.sampling_rate: if self.input_sample_rate != self.sampling_rate:
@ -105,15 +108,16 @@ class TextMelLoader(torch.utils.data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index]) t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index])
orig_output = m.shape[-1] if m is None or \
orig_text_len = t.shape[0] (self.max_mel_len is not None and m.shape[-1] > self.max_mel_len) or \
mel_oversize = self.max_mel_len is not None and m.shape[-1] > self.max_mel_len (self.max_text_len is not None and t.shape[0] > self.max_text_len):
text_oversize = self.max_text_len is not None and t.shape[0] > self.max_text_len if m is not None:
if mel_oversize or text_oversize: print(f"Exception {index} mel_len:{m.shape[-1]} text_len:{t.shape[0]} fname: {p}")
print(f"Exception {index} mel_len:{m.shape[-1]} text_len:{t.shape[0]} fname: {p}")
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
rv = random.randint(0,len(self)-1) rv = random.randint(0,len(self)-1)
return self[rv] return self[rv]
orig_output = m.shape[-1]
orig_text_len = t.shape[0]
if not self.needs_collate: if not self.needs_collate:
if m.shape[-1] != self.max_mel_len: if m.shape[-1] != self.max_mel_len:
m = F.pad(m, (0, self.max_mel_len - m.shape[-1])) m = F.pad(m, (0, self.max_mel_len - m.shape[-1]))
@ -191,9 +195,9 @@ class TextMelCollate():
if __name__ == '__main__': if __name__ == '__main__':
params = { params = {
'mode': 'nv_tacotron', 'mode': 'nv_tacotron',
'path': 'E:\\audio\\MozillaCommonVoice\\en\\test.tsv', 'path': 'E:\\audio\\MozillaCommonVoice\\en\\train.tsv',
'phase': 'train', 'phase': 'train',
'n_workers': 0, 'n_workers': 12,
'batch_size': 32, 'batch_size': 32,
'fetcher_mode': 'mozilla_cv', 'fetcher_mode': 'mozilla_cv',
'needs_collate': False, 'needs_collate': False,
@ -209,9 +213,10 @@ if __name__ == '__main__':
dl = create_dataloader(ds, params, collate_fn=c) dl = create_dataloader(ds, params, collate_fn=c)
i = 0 i = 0
m = None m = None
for i, b in tqdm(enumerate(dl)): for k in range(1000):
continue for i, b in tqdm(enumerate(dl)):
pm = b['padded_mel'] continue
pm = torch.nn.functional.pad(pm, (0, 800-pm.shape[-1])) pm = b['padded_mel']
m = pm if m is None else torch.cat([m, pm], dim=0) pm = torch.nn.functional.pad(pm, (0, 800-pm.shape[-1]))
print(m.mean(), m.std()) m = pm if m is None else torch.cat([m, pm], dim=0)
print(m.mean(), m.std())