Compute mel injector outputs under torch.no_grad() (no gradient tracking)

This commit is contained in:
James Betker 2022-05-23 10:34:53 -06:00
parent dc471f5c6d
commit 36dd4eb61f

View File

@ -69,20 +69,21 @@ class TorchMelSpectrogramInjector(Injector):
self.mel_norms = None self.mel_norms = None
def forward(self, state):
    """Compute a log-compressed mel spectrogram for the audio in ``state``.

    The whole computation runs inside ``torch.no_grad()``, so the returned
    value carries no autograd history.

    Args:
        state: Container indexed by ``self.input`` to fetch the audio. A
            3-D input has dim 1 squeezed out (assumed to be a mono channel
            dim); after that the input must be 2-D.

    Returns:
        A single-entry dict mapping ``self.output`` to the mel result.
    """
    with torch.no_grad():
        audio = state[self.input]
        # Automatically squeeze out the channels dimension if it is present (assuming mono-audio).
        if audio.ndim == 3:
            audio = audio.squeeze(1)
        assert audio.ndim == 2

        # Keep the mel transform on the same device as the incoming audio.
        self.mel_stft = self.mel_stft.to(audio.device)
        spec = self.mel_stft(audio)

        # Perform dynamic range compression.
        spec = torch.clamp(spec, min=1e-5).log()

        if self.mel_norms is not None:
            self.mel_norms = self.mel_norms.to(spec.device)
            # Add leading and trailing singleton dims so the norms broadcast against spec.
            scale = self.mel_norms.unsqueeze(0).unsqueeze(-1)
            spec = spec / scale
        if self.true_norm:
            spec = normalize_mel(spec)
        return {self.output: spec}
class RandomAudioCropInjector(Injector): class RandomAudioCropInjector(Injector):