Spleeter mods

This commit is contained in:
James Betker 2021-09-14 17:43:40 -06:00
parent 0382660159
commit 4334a67924
2 changed files with 5 additions and 5 deletions

View File

@ -7,8 +7,10 @@ from data.util import find_audio_files
class SpleeterDataset(Dataset): class SpleeterDataset(Dataset):
def __init__(self, src_dir, sample_rate=22050, max_duration=20): def __init__(self, src_dir, sample_rate=22050, max_duration=20, skip=0):
self.files = find_audio_files(src_dir, include_nonwav=True) self.files = find_audio_files(src_dir, include_nonwav=True)
if skip > 0:
self.files = self.files[skip:]
self.audio_loader = AudioAdapter.default() self.audio_loader = AudioAdapter.default()
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.max_duration = max_duration self.max_duration = max_duration

View File

@ -17,11 +17,9 @@ def main():
output_sample_rate=22050 output_sample_rate=22050
batch_size=16 batch_size=16
dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True) dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate, skip=batch_size*33000), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate) separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
for e, batch in enumerate(tqdm(dl)): for batch in tqdm(dl):
#if e < 406500:
# continue
waves = batch['wave'] waves = batch['wave']
paths = batch['path'] paths = batch['path']
durations = batch['duration'] durations = batch['duration']