forked from mrq/DL-Art-School
no grads for mel injectors
This commit is contained in:
parent
dc471f5c6d
commit
36dd4eb61f
|
@ -69,20 +69,21 @@ class TorchMelSpectrogramInjector(Injector):
|
|||
self.mel_norms = None
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
||||
inp = inp.squeeze(1)
|
||||
assert len(inp.shape) == 2
|
||||
self.mel_stft = self.mel_stft.to(inp.device)
|
||||
mel = self.mel_stft(inp)
|
||||
# Perform dynamic range compression
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
if self.mel_norms is not None:
|
||||
self.mel_norms = self.mel_norms.to(mel.device)
|
||||
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
||||
if self.true_norm:
|
||||
mel = normalize_mel(mel)
|
||||
return {self.output: mel}
|
||||
with torch.no_grad():
|
||||
inp = state[self.input]
|
||||
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
||||
inp = inp.squeeze(1)
|
||||
assert len(inp.shape) == 2
|
||||
self.mel_stft = self.mel_stft.to(inp.device)
|
||||
mel = self.mel_stft(inp)
|
||||
# Perform dynamic range compression
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
if self.mel_norms is not None:
|
||||
self.mel_norms = self.mel_norms.to(mel.device)
|
||||
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
||||
if self.true_norm:
|
||||
mel = normalize_mel(mel)
|
||||
return {self.output: mel}
|
||||
|
||||
|
||||
class RandomAudioCropInjector(Injector):
|
||||
|
|
Loading…
Reference in New Issue
Block a user