Support norms
This commit is contained in:
parent
959979086d
commit
63bf135b93
|
@ -12,7 +12,7 @@ if __name__ == '__main__':
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='D:\\dlas\\options\\train_dvae_audio_clips.yml')
|
||||
parser.add_argument('-key', type=str, help='Key where audio data is stored', default='clip')
|
||||
parser.add_argument('-num_batches', type=str, help='Number of batches to collect to compute the norm', default=10)
|
||||
parser.add_argument('-num_batches', type=int, help='Number of batches to collect to compute the norm', default=10)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.opt, mode='r') as f:
|
||||
|
|
|
@ -609,6 +609,11 @@ class TorchMelSpectrogramInjector(Injector):
|
|||
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
|
||||
f_max=self.mel_fmax, n_mels=self.n_mel_channels,
|
||||
norm="slaney")
|
||||
self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
|
||||
if self.mel_norm_file is not None:
|
||||
self.mel_norms = torch.load(self.mel_norm_file)
|
||||
else:
|
||||
self.mel_norms = None
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
|
@ -619,6 +624,9 @@ class TorchMelSpectrogramInjector(Injector):
|
|||
mel = self.mel_stft(inp)
|
||||
# Perform dynamic range compression
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
if self.mel_norms is not None:
|
||||
self.mel_norms = self.mel_norms.to(mel.device)
|
||||
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
||||
return {self.output: mel}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user