Spleeter mods
This commit is contained in:
parent
0382660159
commit
4334a67924
|
@ -7,8 +7,10 @@ from data.util import find_audio_files
|
|||
|
||||
|
||||
class SpleeterDataset(Dataset):
|
||||
def __init__(self, src_dir, sample_rate=22050, max_duration=20):
|
||||
def __init__(self, src_dir, sample_rate=22050, max_duration=20, skip=0):
|
||||
self.files = find_audio_files(src_dir, include_nonwav=True)
|
||||
if skip > 0:
|
||||
self.files = self.files[skip:]
|
||||
self.audio_loader = AudioAdapter.default()
|
||||
self.sample_rate = sample_rate
|
||||
self.max_duration = max_duration
|
||||
|
|
|
@ -17,11 +17,9 @@ def main():
|
|||
output_sample_rate=22050
|
||||
batch_size=16
|
||||
|
||||
dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
|
||||
dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate, skip=batch_size*33000), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
|
||||
separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
|
||||
for e, batch in enumerate(tqdm(dl)):
|
||||
#if e < 406500:
|
||||
# continue
|
||||
for batch in tqdm(dl):
|
||||
waves = batch['wave']
|
||||
paths = batch['path']
|
||||
durations = batch['duration']
|
||||
|
|
Loading…
Reference in New Issue
Block a user