gate augmentations with a flag
This commit is contained in:
parent
3f67cb2023
commit
e208d9fb80
|
@ -202,6 +202,7 @@ class ReverseUnivnetInjector(Injector):
|
|||
self.univnet = load_univnet_vocoder().cuda()
|
||||
self.mel_input_key = opt['mel']
|
||||
self.label_output_key = opt['labels']
|
||||
self.do_augmentations = opt_get(opt, ['do_aug'], True)
|
||||
|
||||
def forward(self, state):
|
||||
with torch.no_grad():
|
||||
|
@ -209,6 +210,22 @@ class ReverseUnivnetInjector(Injector):
|
|||
mel = state[self.mel_input_key]
|
||||
decoded_mel = self.univnet.inference(mel)[:,:,:original_audio.shape[-1]]
|
||||
|
||||
if self.do_augmentations:
|
||||
original_audio = original_audio + torch.rand_like(original_audio) * random.random() * .005
|
||||
decoded_mel = decoded_mel + torch.rand_like(decoded_mel) * random.random() * .005
|
||||
if(random.random() < .5):
|
||||
original_audio = torchaudio.functional.resample(torchaudio.functional.resample(original_audio, 24000, 10000), 10000, 24000)
|
||||
if(random.random() < .5):
|
||||
decoded_mel = torchaudio.functional.resample(torchaudio.functional.resample(decoded_mel, 24000, 10000), 10000, 24000)
|
||||
if(random.random() < .5):
|
||||
original_audio = torchaudio.functional.resample(original_audio, 24000, 22000 + random.randint(0,2000))
|
||||
if(random.random() < .5):
|
||||
decoded_mel = torchaudio.functional.resample(decoded_mel, 24000, 22000 + random.randint(0,2000))
|
||||
|
||||
smallest_dim = min(original_audio.shape[-1], decoded_mel.shape[-1])
|
||||
original_audio = original_audio[:,:,:smallest_dim]
|
||||
decoded_mel = decoded_mel[:,:,:smallest_dim]
|
||||
|
||||
labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
|
||||
output = torch.where(labels, original_audio, decoded_mel)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user