diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index 4f178d18..a9769c24 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -69,20 +69,21 @@ class TorchMelSpectrogramInjector(Injector): self.mel_norms = None def forward(self, state): - inp = state[self.input] - if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) - inp = inp.squeeze(1) - assert len(inp.shape) == 2 - self.mel_stft = self.mel_stft.to(inp.device) - mel = self.mel_stft(inp) - # Perform dynamic range compression - mel = torch.log(torch.clamp(mel, min=1e-5)) - if self.mel_norms is not None: - self.mel_norms = self.mel_norms.to(mel.device) - mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) - if self.true_norm: - mel = normalize_mel(mel) - return {self.output: mel} + with torch.no_grad(): + inp = state[self.input] + if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) + inp = inp.squeeze(1) + assert len(inp.shape) == 2 + self.mel_stft = self.mel_stft.to(inp.device) + mel = self.mel_stft(inp) + # Perform dynamic range compression + mel = torch.log(torch.clamp(mel, min=1e-5)) + if self.mel_norms is not None: + self.mel_norms = self.mel_norms.to(mel.device) + mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) + if self.true_norm: + mel = normalize_mel(mel) + return {self.output: mel} class RandomAudioCropInjector(Injector):