forked from mrq/DL-Art-School
no grads for mel injectors
This commit is contained in:
parent
dc471f5c6d
commit
36dd4eb61f
|
@@ -69,20 +69,21 @@ class TorchMelSpectrogramInjector(Injector):
|
||||||
self.mel_norms = None
|
self.mel_norms = None
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
inp = state[self.input]
|
with torch.no_grad():
|
||||||
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
inp = state[self.input]
|
||||||
inp = inp.squeeze(1)
|
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
||||||
assert len(inp.shape) == 2
|
inp = inp.squeeze(1)
|
||||||
self.mel_stft = self.mel_stft.to(inp.device)
|
assert len(inp.shape) == 2
|
||||||
mel = self.mel_stft(inp)
|
self.mel_stft = self.mel_stft.to(inp.device)
|
||||||
# Perform dynamic range compression
|
mel = self.mel_stft(inp)
|
||||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
# Perform dynamic range compression
|
||||||
if self.mel_norms is not None:
|
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||||
self.mel_norms = self.mel_norms.to(mel.device)
|
if self.mel_norms is not None:
|
||||||
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
self.mel_norms = self.mel_norms.to(mel.device)
|
||||||
if self.true_norm:
|
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
||||||
mel = normalize_mel(mel)
|
if self.true_norm:
|
||||||
return {self.output: mel}
|
mel = normalize_mel(mel)
|
||||||
|
return {self.output: mel}
|
||||||
|
|
||||||
|
|
||||||
class RandomAudioCropInjector(Injector):
|
class RandomAudioCropInjector(Injector):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user