forked from mrq/DL-Art-School
37 lines
1.3 KiB
Python
37 lines
1.3 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from spleeter.audio.adapter import AudioAdapter
|
|
from torch.utils.data import Dataset
|
|
|
|
from data.util import find_audio_files
|
|
|
|
|
|
class SpleeterDataset(Dataset):
|
|
def __init__(self, src_dir, sample_rate=22050, max_duration=20):
|
|
self.files = find_audio_files(src_dir, include_nonwav=True)
|
|
self.audio_loader = AudioAdapter.default()
|
|
self.sample_rate = sample_rate
|
|
self.max_duration = max_duration
|
|
|
|
def __getitem__(self, item):
|
|
file = self.files[item]
|
|
try:
|
|
wave, sample_rate = self.audio_loader.load(file, sample_rate=self.sample_rate)
|
|
assert sample_rate == self.sample_rate
|
|
wave = wave[:,0] # strip off channels
|
|
wave = torch.tensor(wave)
|
|
except:
|
|
wave = torch.zeros(self.sample_rate * self.max_duration)
|
|
print(f"Error with {file}")
|
|
original_duration = wave.shape[0]
|
|
padding_needed = self.sample_rate * self.max_duration - original_duration
|
|
if padding_needed > 0:
|
|
wave = nn.functional.pad(wave, (0, padding_needed))
|
|
return {
|
|
'path': file,
|
|
'wave': wave,
|
|
'duration': original_duration,
|
|
}
|
|
|
|
def __len__(self):
|
|
return len(self.files) |