Support norms: optionally divide mel spectrograms by norms loaded from `mel_norm_file`

This commit is contained in:
James Betker 2021-12-11 08:30:49 -07:00
parent 959979086d
commit 63bf135b93
2 changed files with 9 additions and 1 deletion

View File

@ -12,7 +12,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_dvae_audio_clips.yml')
parser.add_argument('-key', type=str, help='Key where audio data is stored', default='clip')
parser.add_argument('-num_batches', type=str, help='Number of batches to collect to compute the norm', default=10)
parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=10)
args = parser.parse_args()
with open(args.opt, mode='r') as f:

View File

@ -609,6 +609,11 @@ class TorchMelSpectrogramInjector(Injector):
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
f_max=self.mel_fmax, n_mels=self.n_mel_channels,
norm="slaney")
self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
if self.mel_norm_file is not None:
self.mel_norms = torch.load(self.mel_norm_file)
else:
self.mel_norms = None
def forward(self, state):
inp = state[self.input]
@ -619,6 +624,9 @@ class TorchMelSpectrogramInjector(Injector):
mel = self.mel_stft(inp)
# Perform dynamic range compression
mel = torch.log(torch.clamp(mel, min=1e-5))
if self.mel_norms is not None:
self.mel_norms = self.mel_norms.to(mel.device)
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
return {self.output: mel}