diff --git a/codes/data/audio/paired_voice_audio_dataset.py b/codes/data/audio/paired_voice_audio_dataset.py index bb636715..e2ae7f3c 100644 --- a/codes/data/audio/paired_voice_audio_dataset.py +++ b/codes/data/audio/paired_voice_audio_dataset.py @@ -89,10 +89,7 @@ class TextWavLoader(torch.utils.data.Dataset): random.shuffle(self.audiopaths_and_text) self.max_wav_len = opt_get(hparams, ['max_wav_length'], None) self.max_text_len = opt_get(hparams, ['max_text_length'], None) - # If needs_collate=False, all outputs will be aligned and padded at maximum length. - self.needs_collate = opt_get(hparams, ['needs_collate'], True) - if not self.needs_collate: - assert self.max_wav_len is not None and self.max_text_len is not None + assert self.max_wav_len is not None and self.max_text_len is not None self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True) if self.use_bpe_tokenizer: from data.audio.voice_tokenizer import VoiceBpeTokenizer @@ -137,83 +134,26 @@ class TextWavLoader(torch.utils.data.Dataset): return self[rv] orig_output = wav.shape[-1] orig_text_len = tseq.shape[0] - if not self.needs_collate: - if wav.shape[-1] != self.max_wav_len: - wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1])) - if tseq.shape[0] != self.max_text_len: - tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0])) - res = { - 'real_text': text, - 'padded_text': tseq, - 'text_lengths': torch.tensor(orig_text_len, dtype=torch.long), - 'wav': wav, - 'wav_lengths': torch.tensor(orig_output, dtype=torch.long), - 'filenames': path - } - if self.load_conditioning: - res['conditioning'] = cond - return res - return tseq, wav, path, text, cond + if wav.shape[-1] != self.max_wav_len: + wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1])) + if tseq.shape[0] != self.max_text_len: + tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0])) + res = { + 'real_text': text, + 'padded_text': tseq, + 'text_lengths': torch.tensor(orig_text_len, dtype=torch.long), + 'wav': wav, + 'wav_lengths': torch.tensor(orig_output, dtype=torch.long), + 'filenames': path + } + if self.load_conditioning: + res['conditioning'] = cond + return res def __len__(self): return len(self.audiopaths_and_text) -class TextMelCollate(): - """ Zero-pads model inputs and targets based on number of frames per step - """ - def __call__(self, batch): - """Collate's training batch from normalized text and wav - PARAMS - ------ - batch: [text_normalized, wav, filename, text] - """ - # Right zero-pad all one-hot text sequences to max input length - input_lengths, ids_sorted_decreasing = torch.sort( - torch.LongTensor([len(x[0]) for x in batch]), - dim=0, descending=True) - max_input_len = input_lengths[0] - - text_padded = torch.LongTensor(len(batch), max_input_len) - text_padded.zero_() - filenames = [] - real_text = [] - conds = [] - for i in range(len(ids_sorted_decreasing)): - text = batch[ids_sorted_decreasing[i]][0] - text_padded[i, :text.size(0)] = text - filenames.append(batch[ids_sorted_decreasing[i]][2]) - real_text.append(batch[ids_sorted_decreasing[i]][3]) - c = batch[ids_sorted_decreasing[i]][4] - if c is not None: - conds.append(c) - - # Right zero-pad wav - num_wavs = batch[0][1].size(0) - max_target_len = max([x[1].size(1) for x in batch]) - - # include mel padded and gate padded - wav_padded = torch.FloatTensor(len(batch), num_wavs, max_target_len) - wav_padded.zero_() - output_lengths = torch.LongTensor(len(batch)) - for i in range(len(ids_sorted_decreasing)): - wav = batch[ids_sorted_decreasing[i]][1] - wav_padded[i, :, :wav.size(1)] = wav - output_lengths[i] = wav.size(1) - - res = { - 'padded_text': text_padded, - 'text_lengths': input_lengths, - 'wav': wav_padded, - 'wav_lengths': output_lengths, - 'filenames': filenames, - 'real_text': real_text, - } - if len(conds) > 0: - res['conditioning'] = torch.stack(conds) - return res - - if __name__ == '__main__': batch_sz = 8 params = { @@ -223,7 +163,6 @@ if __name__ == '__main__': 'phase': 'train', 'n_workers': 0, 'batch_size': batch_sz, - 'needs_collate': True, 'max_wav_length': 255995, 'max_text_length': 200, 'sample_rate': 22050,