forked from mrq/DL-Art-School
allow variable size crops
This commit is contained in:
parent
3330fa2c10
commit
1394213f1e
|
@ -89,14 +89,20 @@ class TorchMelSpectrogramInjector(Injector):
|
||||||
class RandomAudioCropInjector(Injector):
|
class RandomAudioCropInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super().__init__(opt, env)
|
super().__init__(opt, env)
|
||||||
self.crop_sz = opt['crop_size']
|
if 'crop_size' in opt.keys():
|
||||||
|
self.min_crop_sz = opt['crop_size']
|
||||||
|
self.max_crop_sz = self.min_crop_sz
|
||||||
|
else:
|
||||||
|
self.min_crop_sz = opt['min_crop_sz']
|
||||||
|
self.max_crop_sz = opt['max_crop_sz']
|
||||||
self.lengths_key = opt['lengths_key']
|
self.lengths_key = opt['lengths_key']
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
|
crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
|
||||||
inp = state[self.input]
|
inp = state[self.input]
|
||||||
lens = state[self.lengths_key]
|
lens = state[self.lengths_key]
|
||||||
len = torch.min(lens)
|
len = torch.min(lens)
|
||||||
margin = len - self.crop_sz
|
margin = len - crop_sz
|
||||||
if margin < 0:
|
if margin < 0:
|
||||||
return {self.output: inp}
|
return {self.output: inp}
|
||||||
start = random.randint(0, margin)
|
start = random.randint(0, margin)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user