Spleeter mods
This commit is contained in:
parent
0382660159
commit
4334a67924
|
@ -7,8 +7,10 @@ from data.util import find_audio_files
|
||||||
|
|
||||||
|
|
||||||
class SpleeterDataset(Dataset):
|
class SpleeterDataset(Dataset):
|
||||||
def __init__(self, src_dir, sample_rate=22050, max_duration=20):
|
def __init__(self, src_dir, sample_rate=22050, max_duration=20, skip=0):
|
||||||
self.files = find_audio_files(src_dir, include_nonwav=True)
|
self.files = find_audio_files(src_dir, include_nonwav=True)
|
||||||
|
if skip > 0:
|
||||||
|
self.files = self.files[skip:]
|
||||||
self.audio_loader = AudioAdapter.default()
|
self.audio_loader = AudioAdapter.default()
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.max_duration = max_duration
|
self.max_duration = max_duration
|
||||||
|
|
|
@ -17,11 +17,9 @@ def main():
|
||||||
output_sample_rate=22050
|
output_sample_rate=22050
|
||||||
batch_size=16
|
batch_size=16
|
||||||
|
|
||||||
dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
|
dl = DataLoader(SpleeterDataset(src_dir, output_sample_rate, skip=batch_size*33000), batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
|
||||||
separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
|
separator = Separator('pretrained_models/2stems', input_sr=output_sample_rate)
|
||||||
for e, batch in enumerate(tqdm(dl)):
|
for batch in tqdm(dl):
|
||||||
#if e < 406500:
|
|
||||||
# continue
|
|
||||||
waves = batch['wave']
|
waves = batch['wave']
|
||||||
paths = batch['path']
|
paths = batch['path']
|
||||||
durations = batch['duration']
|
durations = batch['duration']
|
||||||
|
|
Loading…
Reference in New Issue
Block a user