class RandomAudioCropInjector(Injector):
    """Cut a random, fixed-size window out of the last axis of the input.

    Options read from ``opt``:
        crop_size: Number of elements to keep along the last axis.
        lengths_key: Optional. Key into ``state`` holding the valid length of
            every batch element. When present, the window is drawn from within
            the shortest element. When absent, the full size of the last axis
            is used.

    The input is indexed as ``inp[:, :, start:stop]``, so it must have at
    least three axes (presumably batch, channel, time -- confirm with caller).
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.crop_sz = opt['crop_size']
        # Kept optional so configurations that do not define the option still
        # load and get the full-axis behaviour.
        try:
            self.lengths_key = opt['lengths_key']
        except KeyError:
            self.lengths_key = None

    def forward(self, state):
        """Return ``{self.output: cropped}`` for the tensor at ``state[self.input]``.

        If fewer than ``crop_size`` usable elements are available the input is
        passed through untouched, so callers may receive a last axis whose
        size differs from ``crop_size`` in that case.
        """
        inp = state[self.input]
        full_len = inp.shape[-1]

        if self.lengths_key is None:
            usable_len = full_len
        else:
            # int() turns the 0-dim result into a plain Python integer so the
            # arithmetic and random.randint below never see a tensor.
            usable_len = int(torch.min(state[self.lengths_key]))
            # A reported length beyond the real axis size would otherwise
            # allow a window that runs off the end and comes back short.
            usable_len = min(usable_len, full_len)

        margin = usable_len - self.crop_sz
        if margin < 0:
            # Not enough material for a full window; hand back the input as is.
            return {self.output: inp}

        # random.randint is inclusive on both ends, so margin itself is valid.
        start = random.randint(0, margin)
        return {self.output: inp[:, :, start:start + self.crop_sz]}