From 4c76257c71817485aa14dc903287e2890bb6b7b1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 12 Aug 2021 15:44:55 -0600 Subject: [PATCH] Don't require collation for nv_tacotron --- codes/data/__init__.py | 3 ++- codes/data/audio/nv_tacotron_dataset.py | 27 ++++++++++++++++++++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/codes/data/__init__.py b/codes/data/__init__.py index ea4bb4e0..4f0c0c7c 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -69,7 +69,8 @@ def create_dataset(dataset_opt, return_collate=False): default_params = create_hparams() default_params.update(dataset_opt) dataset_opt = munchify(default_params) - collate = C(dataset_opt.n_frames_per_step) + if opt_get(dataset_opt, ['needs_collate'], True): + collate = C(dataset_opt.n_frames_per_step) elif mode == 'gpt_tts': from data.audio.gpt_tts_dataset import GptTtsDataset as D from data.audio.gpt_tts_dataset import GptTtsCollater as C diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index 446517fc..bbe22938 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -5,6 +5,7 @@ import audio2numpy import numpy as np import torch import torch.utils.data +import torch.nn.functional as F from tqdm import tqdm import models.tacotron2.layers as layers @@ -40,7 +41,6 @@ class TextMelLoader(torch.utils.data.Dataset): self.audiopaths_and_text = fetcher_fn(hparams['path']) self.text_cleaners = hparams.text_cleaners self.max_wav_value = hparams.max_wav_value - self.max_mel_len = opt_get(hparams, ['max_mel_length'], None) self.sampling_rate = hparams.sampling_rate self.load_mel_from_disk = hparams.load_mel_from_disk self.return_wavs = opt_get(hparams, ['return_wavs'], False) @@ -52,6 +52,12 @@ class TextMelLoader(torch.utils.data.Dataset): hparams.mel_fmax) random.seed(hparams.seed) random.shuffle(self.audiopaths_and_text) + self.max_mel_len = opt_get(hparams, ['max_mel_length'], 
None) + self.max_text_len = opt_get(hparams, ['max_text_length'], None) + # If needs_collate=False, all outputs will be aligned and padded at maximum length. + self.needs_collate = opt_get(hparams, ['needs_collate'], True) + if not self.needs_collate: + assert self.max_mel_len is not None and self.max_text_len is not None def get_mel_text_pair(self, audiopath_and_text): # separate filename and text @@ -88,8 +94,8 @@ class TextMelLoader(torch.utils.data.Dataset): else: melspec = torch.from_numpy(np.load(filename)) assert melspec.size(0) == self.stft.n_mel_channels, ( - 'Mel dimension mismatch: given {}, expected {}'.format( - melspec.size(0), self.stft.n_mel_channels)) + 'Mel dimension mismatch: given {}, expected {}'.format(melspec.size(0), self.stft.n_mel_channels)) + return melspec @@ -99,11 +105,18 @@ class TextMelLoader(torch.utils.data.Dataset): def __getitem__(self, index): t, m, p = self.get_mel_text_pair(self.audiopaths_and_text[index]) - if self.max_mel_len is not None and m.shape[-1] > self.max_mel_len: - print(f"Exception {index} mel_len:{m.shape[-1]} fname: {p}") + mel_oversize = self.max_mel_len is not None and m.shape[-1] > self.max_mel_len + text_oversize = self.max_text_len is not None and t.shape[0] > self.max_text_len + if mel_oversize or text_oversize: + print(f"Exception {index} mel_len:{m.shape[-1]} text_len:{t.shape[0]} fname: {p}") # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. 
rv = random.randint(0,len(self)-1) return self[rv] + if not self.needs_collate: + if m.shape[-1] != self.max_mel_len: + m = F.pad(m, (0, self.max_mel_len - m.shape[-1])) + if t.shape[0] != self.max_text_len: + t = F.pad(t, (0, self.max_text_len - t.shape[0])) return t, m, p def __len__(self): @@ -174,6 +187,9 @@ if __name__ == '__main__': 'n_workers': 0, 'batch_size': 32, 'fetcher_mode': 'mozilla_cv', + 'needs_collate': False, + 'max_mel_length': 800, + 'max_text_length': 200, #'return_wavs': True, #'input_sample_rate': 22050, #'sampling_rate': 8000 @@ -185,6 +201,7 @@ if __name__ == '__main__': i = 0 m = None for i, b in tqdm(enumerate(dl)): + continue pm = b['padded_mel'] pm = torch.nn.functional.pad(pm, (0, 800-pm.shape[-1])) m = pm if m is None else torch.cat([m, pm], dim=0)