From 4b2946e581ace6f5bd845e0e3ae11c06badbe1fd Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 12 Aug 2021 15:51:23 -0600 Subject: [PATCH] More fix --- codes/data/audio/nv_tacotron_dataset.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index bbe22938..59e2ea9a 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -105,6 +105,8 @@ class TextMelLoader(torch.utils.data.Dataset): def __getitem__(self, index): t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index]) + orig_output = m.shape[-1] + orig_text_len = t.shape[0] mel_oversize = self.max_mel_len is not None and m.shape[-1] > self.max_mel_len text_oversize = self.max_text_len is not None and t.shape[0] > self.max_text_len if mel_oversize or text_oversize: @@ -117,6 +119,13 @@ class TextMelLoader(torch.utils.data.Dataset): m = F.pad(m, (0, self.max_mel_len - m.shape[-1])) if t.shape[0] != self.max_text_len: t = F.pad(t, (0, self.max_text_len - t.shape[0])) + return { + 'padded_text': t, + 'input_lengths': torch.tensor(orig_text_len, dtype=torch.long), + 'padded_mel': m, + 'output_lengths': torch.tensor(orig_output, dtype=torch.long), + 'filenames': [p] + } return t, m, p def __len__(self):