no grads for mel injectors

This commit is contained in:
James Betker 2022-05-23 10:34:53 -06:00
parent dc471f5c6d
commit 36dd4eb61f

View File

@ -69,20 +69,21 @@ class TorchMelSpectrogramInjector(Injector):
self.mel_norms = None
def forward(self, state):
inp = state[self.input]
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
inp = inp.squeeze(1)
assert len(inp.shape) == 2
self.mel_stft = self.mel_stft.to(inp.device)
mel = self.mel_stft(inp)
# Perform dynamic range compression
mel = torch.log(torch.clamp(mel, min=1e-5))
if self.mel_norms is not None:
self.mel_norms = self.mel_norms.to(mel.device)
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
if self.true_norm:
mel = normalize_mel(mel)
return {self.output: mel}
with torch.no_grad():
inp = state[self.input]
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
inp = inp.squeeze(1)
assert len(inp.shape) == 2
self.mel_stft = self.mel_stft.to(inp.device)
mel = self.mel_stft(inp)
# Perform dynamic range compression
mel = torch.log(torch.clamp(mel, min=1e-5))
if self.mel_norms is not None:
self.mel_norms = self.mel_norms.to(mel.device)
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
if self.true_norm:
mel = normalize_mel(mel)
return {self.output: mel}
class RandomAudioCropInjector(Injector):