diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index 7634effb..3c8e29cc 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -89,14 +89,20 @@ class TorchMelSpectrogramInjector(Injector): class RandomAudioCropInjector(Injector): def __init__(self, opt, env): super().__init__(opt, env) - self.crop_sz = opt['crop_size'] + if 'crop_size' in opt.keys(): + self.min_crop_sz = opt['crop_size'] + self.max_crop_sz = self.min_crop_sz + else: + self.min_crop_sz = opt['min_crop_sz'] + self.max_crop_sz = opt['max_crop_sz'] self.lengths_key = opt['lengths_key'] def forward(self, state): + crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz) inp = state[self.input] lens = state[self.lengths_key] len = torch.min(lens) - margin = len - self.crop_sz + margin = len - crop_sz if margin < 0: return {self.output: inp} start = random.randint(0, margin)