allow variable size crops

This commit is contained in:
James Betker 2022-06-21 19:48:07 -06:00
parent 3330fa2c10
commit 1394213f1e

View File

@ -89,14 +89,20 @@ class TorchMelSpectrogramInjector(Injector):
class RandomAudioCropInjector(Injector): class RandomAudioCropInjector(Injector):
def __init__(self, opt, env): def __init__(self, opt, env):
super().__init__(opt, env) super().__init__(opt, env)
self.crop_sz = opt['crop_size'] if 'crop_size' in opt.keys():
self.min_crop_sz = opt['crop_size']
self.max_crop_sz = self.min_crop_sz
else:
self.min_crop_sz = opt['min_crop_sz']
self.max_crop_sz = opt['max_crop_sz']
self.lengths_key = opt['lengths_key'] self.lengths_key = opt['lengths_key']
def forward(self, state): def forward(self, state):
crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
inp = state[self.input] inp = state[self.input]
lens = state[self.lengths_key] lens = state[self.lengths_key]
len = torch.min(lens) len = torch.min(lens)
margin = len - self.crop_sz margin = len - crop_sz
if margin < 0: if margin < 0:
return {self.output: inp} return {self.output: inp}
start = random.randint(0, margin) start = random.randint(0, margin)