diff --git a/codes/scripts/audio/mel_bin_norm_compute.py b/codes/scripts/audio/mel_bin_norm_compute.py
index e50fff20..8155c6de 100644
--- a/codes/scripts/audio/mel_bin_norm_compute.py
+++ b/codes/scripts/audio/mel_bin_norm_compute.py
@@ -5,14 +5,14 @@ import yaml
 from tqdm import tqdm
 
 from data import create_dataset, create_dataloader
-from trainer.injectors.base_injectors import TorchMelSpectrogramInjector
+from scripts.audio.gen.speech_synthesis_utils import wav_to_univnet_mel
 from utils.options import Loader
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_dvae_audio_clips.yml')
-    parser.add_argument('-key', type=str, help='Key where audio data is stored', default='clip')
-    parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=10)
+    parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_diffusion_tts9.yml')
+    parser.add_argument('-key', type=str, help='Key where audio data is stored', default='wav')
+    parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=50000)
     args = parser.parse_args()
 
     with open(args.opt, mode='r') as f:
@@ -21,14 +21,26 @@ if __name__ == '__main__':
         dopt['phase'] = 'train'
         dataset, collate = create_dataset(dopt, return_collate=True)
         dataloader = create_dataloader(dataset, dopt, collate_fn=collate, shuffle=True)
-        inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel'},{}).cuda()
 
-        mels = []
+        mel_means = []
+        mel_max = -999999999
+        mel_min = 999999999
+        mel_stds = []
+        mel_vars = []
         for batch in tqdm(dataloader):
-            clip = batch[args.key].cuda()
-            mel = inj({'wav': clip})['mel']
-            mels.append(mel.mean((0,2)).cpu())
-            if len(mels) > args.num_batches:
+            if len(mel_means) > args.num_batches:
                 break
-        mel_norms = torch.stack(mels).mean(0)
-        torch.save(mel_norms, 'mel_norms.pth')
\ No newline at end of file
+            clip = batch[args.key].cuda()
+            for b in range(clip.shape[0]):
+                wav = clip[b].unsqueeze(0)
+                wav = wav[:, :, :batch[f'{args.key}_lengths'][b]]
+                mel = wav_to_univnet_mel(wav)  # Use the length-trimmed clip, not the padded batch. Caution: make sure this isn't already normed.
+                mel_means.append(mel.mean((0,2)).cpu())
+                mel_max = max(mel.max().item(), mel_max)
+                mel_min = min(mel.min().item(), mel_min)
+                mel_stds.append(mel.std((0,2)).cpu())
+                mel_vars.append(mel.var((0,2)).cpu())
+        mel_means = torch.stack(mel_means).mean(0)
+        mel_stds = torch.stack(mel_stds).mean(0)
+        mel_vars = torch.stack(mel_vars).mean(0)
+        torch.save((mel_means,mel_max,mel_min,mel_stds,mel_vars), 'univnet_mel_norms.pth')
\ No newline at end of file
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index 694981d0..e255cd3e 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -21,6 +21,14 @@ class MelSpectrogramInjector(Injector):
         mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
         sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
         self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
+        self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
+        if self.mel_norm_file is not None:
+            # Note that this format is different from TorchMelSpectrogramInjector.
+            mel_means, mel_max, mel_min, mel_stds, mel_vars = torch.load(self.mel_norm_file)
+            self.mel_max = mel_max
+            self.mel_min = mel_min
+        else:
+            self.mel_norms = None
 
     def forward(self, state):
         inp = state[self.input]
@@ -28,7 +36,10 @@ class MelSpectrogramInjector(Injector):
             inp = inp.squeeze(1)
         assert len(inp.shape) == 2
         self.stft = self.stft.to(inp.device)
-        return {self.output: self.stft.mel_spectrogram(inp)}
+        mel = self.stft.mel_spectrogram(inp)
+        if self.mel_norm_file is not None:  # mel_min/mel_max are only set when a norm file was loaded.
+            mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1
+        return {self.output: mel}
 
 
 class TorchMelSpectrogramInjector(Injector):