diff --git a/codes/models/audio/music/music_gen_fill_gaps.py b/codes/models/audio/music/music_gen_fill_gaps.py index 7b53ca95..14eaf563 100644 --- a/codes/models/audio/music/music_gen_fill_gaps.py +++ b/codes/models/audio/music/music_gen_fill_gaps.py @@ -172,7 +172,7 @@ class MusicGenerator(nn.Module): def do_masking(self, truth): b, c, s = truth.shape mask = torch.ones_like(truth) - if self.frequency_mask_percent_mask > 0: + if self.random() > .5: # Frequency mask cs = random.randint(0, c-10) ce = min(c-1, cs+random.randint(1, int(self.frequency_mask_percent_mask*c)))