diff --git a/codes/scripts/audio/mel_bin_norm_compute.py b/codes/scripts/audio/mel_bin_norm_compute.py
index 663da801..e50fff20 100644
--- a/codes/scripts/audio/mel_bin_norm_compute.py
+++ b/codes/scripts/audio/mel_bin_norm_compute.py
@@ -12,7 +12,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_dvae_audio_clips.yml')
     parser.add_argument('-key', type=str, help='Key where audio data is stored', default='clip')
-    parser.add_argument('-num_batches', type=str, help='Number of batches to collect to compute the norm', default=10)
+    parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=10)
     args = parser.parse_args()
 
     with open(args.opt, mode='r') as f:
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index a94155e6..e909a31a 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -609,6 +609,11 @@ class TorchMelSpectrogramInjector(Injector):
                                                              sample_rate=self.sampling_rate, f_min=self.mel_fmin,
                                                              f_max=self.mel_fmax, n_mels=self.n_mel_channels,
                                                              norm="slaney")
+        self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
+        if self.mel_norm_file is not None:
+            self.mel_norms = torch.load(self.mel_norm_file)
+        else:
+            self.mel_norms = None
 
     def forward(self, state):
         inp = state[self.input]
@@ -619,6 +624,9 @@ class TorchMelSpectrogramInjector(Injector):
         mel = self.mel_stft(inp)
         # Perform dynamic range compression
         mel = torch.log(torch.clamp(mel, min=1e-5))
+        if self.mel_norms is not None:
+            self.mel_norms = self.mel_norms.to(mel.device)
+            mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
         return {self.output: mel}