diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index 4f178d18..a9769c24 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -69,20 +69,21 @@ class TorchMelSpectrogramInjector(Injector):
             self.mel_norms = None
 
     def forward(self, state):
-        inp = state[self.input]
-        if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
-            inp = inp.squeeze(1)
-        assert len(inp.shape) == 2
-        self.mel_stft = self.mel_stft.to(inp.device)
-        mel = self.mel_stft(inp)
-        # Perform dynamic range compression
-        mel = torch.log(torch.clamp(mel, min=1e-5))
-        if self.mel_norms is not None:
-            self.mel_norms = self.mel_norms.to(mel.device)
-            mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
-        if self.true_norm:
-            mel = normalize_mel(mel)
-        return {self.output: mel}
+        with torch.no_grad():
+            inp = state[self.input]
+            if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
+                inp = inp.squeeze(1)
+            assert len(inp.shape) == 2
+            self.mel_stft = self.mel_stft.to(inp.device)
+            mel = self.mel_stft(inp)
+            # Perform dynamic range compression
+            mel = torch.log(torch.clamp(mel, min=1e-5))
+            if self.mel_norms is not None:
+                self.mel_norms = self.mel_norms.to(mel.device)
+                mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
+            if self.true_norm:
+                mel = normalize_mel(mel)
+            return {self.output: mel}
 
 
 class RandomAudioCropInjector(Injector):