From e208d9fb8096f5feaffaaf0b9e7442a0d567fec4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 28 Apr 2022 10:09:22 -0600 Subject: [PATCH] gate augmentations with a flag --- codes/trainer/injectors/audio_injectors.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index 6393f52e..4ed8d5c8 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -202,6 +202,7 @@ class ReverseUnivnetInjector(Injector): self.univnet = load_univnet_vocoder().cuda() self.mel_input_key = opt['mel'] self.label_output_key = opt['labels'] + self.do_augmentations = opt_get(opt, ['do_aug'], True) def forward(self, state): with torch.no_grad(): @@ -209,6 +210,22 @@ class ReverseUnivnetInjector(Injector): mel = state[self.mel_input_key] decoded_mel = self.univnet.inference(mel)[:,:,:original_audio.shape[-1]] + if self.do_augmentations: + original_audio = original_audio + torch.rand_like(original_audio) * random.random() * .005 + decoded_mel = decoded_mel + torch.rand_like(decoded_mel) * random.random() * .005 + if(random.random() < .5): + original_audio = torchaudio.functional.resample(torchaudio.functional.resample(original_audio, 24000, 10000), 10000, 24000) + if(random.random() < .5): + decoded_mel = torchaudio.functional.resample(torchaudio.functional.resample(decoded_mel, 24000, 10000), 10000, 24000) + if(random.random() < .5): + original_audio = torchaudio.functional.resample(original_audio, 24000, 22000 + random.randint(0,2000)) + if(random.random() < .5): + decoded_mel = torchaudio.functional.resample(decoded_mel, 24000, 22000 + random.randint(0,2000)) + + smallest_dim = min(original_audio.shape[-1], decoded_mel.shape[-1]) + original_audio = original_audio[:,:,:smallest_dim] + decoded_mel = decoded_mel[:,:,:smallest_dim] + labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5) output = torch.where(labels, original_audio, decoded_mel)