This commit is contained in:
James Betker 2021-08-12 15:51:23 -06:00
parent 4c76257c71
commit 4b2946e581

View File

@ -105,6 +105,8 @@ class TextMelLoader(torch.utils.data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index]) t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index])
orig_output = m.shape[-1]
orig_text_len = t.shape[0]
mel_oversize = self.max_mel_len is not None and m.shape[-1] > self.max_mel_len mel_oversize = self.max_mel_len is not None and m.shape[-1] > self.max_mel_len
text_oversize = self.max_text_len is not None and t.shape[0] > self.max_text_len text_oversize = self.max_text_len is not None and t.shape[0] > self.max_text_len
if mel_oversize or text_oversize: if mel_oversize or text_oversize:
@ -117,6 +119,13 @@ class TextMelLoader(torch.utils.data.Dataset):
m = F.pad(m, (0, self.max_mel_len - m.shape[-1])) m = F.pad(m, (0, self.max_mel_len - m.shape[-1]))
if t.shape[0] != self.max_text_len: if t.shape[0] != self.max_text_len:
t = F.pad(t, (0, self.max_text_len - t.shape[0])) t = F.pad(t, (0, self.max_text_len - t.shape[0]))
return {
'padded_text': t,
'input_lengths': torch.tensor(orig_text_len, dtype=torch.long),
'padded_mel': m,
'output_lengths': torch.tensor(orig_output, dtype=torch.long),
'filenames': [p]
}
return t, m, p return t, m, p
def __len__(self): def __len__(self):