diff --git a/codes/scripts/audio/mel_bin_norm_compute.py b/codes/scripts/audio/mel_bin_norm_compute.py index 663da801..e50fff20 100644 --- a/codes/scripts/audio/mel_bin_norm_compute.py +++ b/codes/scripts/audio/mel_bin_norm_compute.py @@ -12,7 +12,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_dvae_audio_clips.yml') parser.add_argument('-key', type=str, help='Key where audio data is stored', default='clip') - parser.add_argument('-num_batches', type=str, help='Number of batches to collect to compute the norm', default=10) + parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=10) args = parser.parse_args() with open(args.opt, mode='r') as f: diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index a94155e6..e909a31a 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -609,6 +609,11 @@ class TorchMelSpectrogramInjector(Injector): sample_rate=self.sampling_rate, f_min=self.mel_fmin, f_max=self.mel_fmax, n_mels=self.n_mel_channels, norm="slaney") + self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None) + if self.mel_norm_file is not None: + self.mel_norms = torch.load(self.mel_norm_file) + else: + self.mel_norms = None def forward(self, state): inp = state[self.input] @@ -619,6 +624,9 @@ class TorchMelSpectrogramInjector(Injector): mel = self.mel_stft(inp) # Perform dynamic range compression mel = torch.log(torch.clamp(mel, min=1e-5)) + if self.mel_norms is not None: + self.mel_norms = self.mel_norms.to(mel.device) + mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) return {self.output: mel}