gate augmentations with a flag

This commit is contained in:
James Betker 2022-04-28 10:09:22 -06:00
parent 3f67cb2023
commit e208d9fb80

View File

@@ -202,6 +202,7 @@ class ReverseUnivnetInjector(Injector):
self.univnet = load_univnet_vocoder().cuda()
self.mel_input_key = opt['mel']
self.label_output_key = opt['labels']
self.do_augmentations = opt_get(opt, ['do_aug'], True)
def forward(self, state):
with torch.no_grad():
@@ -209,6 +210,22 @@ class ReverseUnivnetInjector(Injector):
mel = state[self.mel_input_key]
decoded_mel = self.univnet.inference(mel)[:,:,:original_audio.shape[-1]]
if self.do_augmentations:
original_audio = original_audio + torch.rand_like(original_audio) * random.random() * .005
decoded_mel = decoded_mel + torch.rand_like(decoded_mel) * random.random() * .005
if(random.random() < .5):
original_audio = torchaudio.functional.resample(torchaudio.functional.resample(original_audio, 24000, 10000), 10000, 24000)
if(random.random() < .5):
decoded_mel = torchaudio.functional.resample(torchaudio.functional.resample(decoded_mel, 24000, 10000), 10000, 24000)
if(random.random() < .5):
original_audio = torchaudio.functional.resample(original_audio, 24000, 22000 + random.randint(0,2000))
if(random.random() < .5):
decoded_mel = torchaudio.functional.resample(decoded_mel, 24000, 22000 + random.randint(0,2000))
smallest_dim = min(original_audio.shape[-1], decoded_mel.shape[-1])
original_audio = original_audio[:,:,:smallest_dim]
decoded_mel = decoded_mel[:,:,:smallest_dim]
labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
output = torch.where(labels, original_audio, decoded_mel)