fix random_audio_crop injector

This commit is contained in:
James Betker 2022-03-12 20:42:29 -07:00
parent 8f130e2b3f
commit 08599b4c75

View File

@ -74,11 +74,15 @@ class RandomAudioCropInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.crop_sz = opt['crop_size']
self.lengths_key = opt['lengths_key']
def forward(self, state):
inp = state[self.input]
len = inp.shape[-1]
lens = state[self.lengths_key]
len = torch.min(lens)
margin = len - self.crop_sz
if margin < 0:
return {self.output: inp}
start = random.randint(0, margin)
return {self.output: inp[:, :, start:start+self.crop_sz]}