Add a tool to split mp3 files into arbitrary chunks of wav files

This commit is contained in:
James Betker 2021-08-08 23:23:13 -06:00
parent 01cfae28d8
commit 4100469902
10 changed files with 82 additions and 37 deletions

View File

@ -0,0 +1,41 @@
import audio2numpy
from scipy.io import wavfile
from tqdm import tqdm
from data.util import find_audio_files
import numpy as np
import torch
import torch.nn.functional as F
import os.path as osp
if __name__ == '__main__':
src_dir = 'E:\\audio\\books'
clip_length = 5 # In seconds
sparsity = .05 # Only this proportion of the total clips are extracted as wavs.
output_sample_rate=22050
output_dir = 'E:\\audio\\books-clips'
files = find_audio_files(src_dir, include_nonwav=True)
for e, file in enumerate(tqdm(files)):
if e < 7250:
continue
file_basis = osp.relpath(file, src_dir).replace('/', '_').replace('\\', '_')
try:
wave, sample_rate = audio2numpy.open_audio(file)
except:
print(f"Error with {file}")
continue
wave = torch.tensor(wave)
# Strip out channels.
if len(wave.shape) > 1:
wave = wave[0] # Just use the first channel.
# Calculate how much data we need to extract for each clip.
clip_sz = sample_rate * clip_length
interval = int(sample_rate * (clip_length / sparsity))
i = 0
while (i+clip_sz) < wave.shape[-1]:
clip = wave[i:i+clip_sz]
clip = F.interpolate(clip.view(1,1,clip_sz), scale_factor=output_sample_rate/sample_rate).squeeze()
wavfile.write(osp.join(output_dir, f'{file_basis}_{i}.wav'), output_sample_rate, clip.numpy())
i = i + interval

View File

@ -7,7 +7,7 @@ import torchaudio
from tqdm import tqdm from tqdm import tqdm
from data.audio.wav_aug import WavAugmentor from data.audio.wav_aug import WavAugmentor
from data.util import get_image_paths, is_wav_file from data.util import find_files_of_type, is_wav_file
from models.tacotron2.taco_utils import load_wav_to_torch from models.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get from utils.util import opt_get
@ -21,7 +21,7 @@ class WavfileDataset(torch.utils.data.Dataset):
self.audiopaths = torch.load(cache_path) self.audiopaths = torch.load(cache_path)
else: else:
print("Building cache..") print("Building cache..")
self.audiopaths = get_image_paths('img', opt['path'], qualifier=is_wav_file)[0] self.audiopaths = find_files_of_type('img', opt['path'], qualifier=is_wav_file)[0]
torch.save(self.audiopaths, cache_path) torch.save(self.audiopaths, cache_path)
# Parse options # Parse options

View File

@ -10,7 +10,7 @@ from utils.util import opt_get
class ChunkWithReference: class ChunkWithReference:
def __init__(self, opt, path): def __init__(self, opt, path):
self.path = path.path self.path = path.path
self.tiles, _ = util.get_image_paths('img', self.path) self.tiles, _ = util.find_files_of_type('img', self.path)
self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False) self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False)
self.need_ref = opt_get(opt, ['need_ref'], False) self.need_ref = opt_get(opt, ['need_ref'], False)
if 'ignore_first' in opt.keys(): if 'ignore_first' in opt.keys():

View File

@ -29,17 +29,17 @@ class FullImageDataset(data.Dataset):
self.LQ_env, self.GT_env = None, None self.LQ_env, self.GT_env = None, None
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1 self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights']) self.paths_GT, self.sizes_GT = util.find_files_of_type(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
if 'dataroot_LQ' in opt.keys(): if 'dataroot_LQ' in opt.keys():
self.paths_LQ = [] self.paths_LQ = []
if isinstance(opt['dataroot_LQ'], list): if isinstance(opt['dataroot_LQ'], list):
# Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
# we want the model to learn them all. # we want the model to learn them all.
for dr_lq in opt['dataroot_LQ']: for dr_lq in opt['dataroot_LQ']:
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq) lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, dr_lq)
self.paths_LQ.append(lq_path) self.paths_LQ.append(lq_path)
else: else:
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, opt['dataroot_LQ'])
self.paths_LQ.append(lq_path) self.paths_LQ.append(lq_path)
assert self.paths_GT, 'Error: GT path is empty.' assert self.paths_GT, 'Error: GT path is empty.'

View File

@ -85,7 +85,7 @@ class ImageFolderDataset:
imgs = torch.load(cache_path) imgs = torch.load(cache_path)
else: else:
print("Building image folder cache, this can take some time for large datasets..") print("Building image folder cache, this can take some time for large datasets..")
imgs = util.get_image_paths('img', path)[0] imgs = util.find_files_of_type('img', path)[0]
torch.save(imgs, cache_path) torch.save(imgs, cache_path)
for w in range(weight): for w in range(weight):
self.image_paths.extend(imgs) self.image_paths.extend(imgs)

View File

@ -40,7 +40,7 @@ class MultiScaleDataset(data.Dataset):
self.num_scales = self.opt['num_scales'] self.num_scales = self.opt['num_scales']
self.hq_size_cap = self.tile_size * 2 ** self.num_scales self.hq_size_cap = self.tile_size * 2 ** self.num_scales
self.scale = self.opt['scale'] self.scale = self.opt['scale']
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']]) self.paths_hq, self.sizes_hq = util.find_files_of_type(self.data_type, opt['paths'], [1 for _ in opt['paths']])
self.corruptor = ImageCorruptor(opt) self.corruptor = ImageCorruptor(opt)

View File

@ -39,10 +39,16 @@ def cv2torch(cv, batchify=True):
def is_image_file(filename): def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def is_wav_file(filename): def is_wav_file(filename):
return filename.endswith('.wav') return filename.endswith('.wav')
def is_audio_file(filename):
AUDIO_EXTENSIONS = ['.wav', '.mp3', '.wma', 'm4b']
return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)
def _get_paths_from_images(path, qualifier=is_image_file): def _get_paths_from_images(path, qualifier=is_image_file):
"""get image path list from image folder""" """get image path list from image folder"""
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
@ -67,14 +73,14 @@ def _get_paths_from_lmdb(dataroot):
return paths, sizes return paths, sizes
def get_image_paths(data_type, dataroot, weights=[], qualifier=is_image_file): def find_audio_files(dataroot, include_nonwav=False):
"""get image path list if include_nonwav:
support lmdb or image files""" return find_files_of_type(None, dataroot, qualifier=is_audio_file)[0]
paths, sizes = None, None else:
if dataroot is not None: return find_files_of_type(None, dataroot, qualifier=is_wav_file)[0]
if data_type == 'lmdb':
paths, sizes = _get_paths_from_lmdb(dataroot)
elif data_type == 'img': def find_files_of_type(data_type, dataroot, weights=[], qualifier=is_image_file):
if isinstance(dataroot, list): if isinstance(dataroot, list):
paths = [] paths = []
for i in range(len(dataroot)): for i in range(len(dataroot)):
@ -91,8 +97,6 @@ def get_image_paths(data_type, dataroot, weights=[], qualifier=is_image_file):
else: else:
paths = sorted(_get_paths_from_images(dataroot, qualifier)) paths = sorted(_get_paths_from_images(dataroot, qualifier))
sizes = len(paths) sizes = len(paths)
else:
raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
return paths, sizes return paths, sizes

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from data.util import is_wav_file, get_image_paths from data.util import is_wav_file, find_files_of_type
from models.audio_resnet import resnet34, resnet50 from models.audio_resnet import resnet34, resnet50
from models.tacotron2.taco_utils import load_wav_to_torch from models.tacotron2.taco_utils import load_wav_to_torch
from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict
@ -12,7 +12,7 @@ from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_stat
if __name__ == '__main__': if __name__ == '__main__':
window = 48000 window = 48000
root_path = 'D:\\tmp\\clips' root_path = 'D:\\tmp\\clips'
paths = get_image_paths('img', root_path, qualifier=is_wav_file)[0] paths = find_files_of_type('img', root_path, qualifier=is_wav_file)[0]
clips = [] clips = []
for path in paths: for path in paths:
clip, sr = load_wav_to_torch(os.path.join(root_path, path)) clip, sr = load_wav_to_torch(os.path.join(root_path, path))

View File

@ -39,7 +39,7 @@ class TiledDataset(data.Dataset):
def __init__(self, opt): def __init__(self, opt):
self.opt = opt self.opt = opt
input_folder = opt['input_folder'] input_folder = opt['input_folder']
self.images = data_util.get_image_paths('img', input_folder)[0] self.images = data_util.find_files_of_type('img', input_folder)[0]
print("Found %i images" % (len(self.images),)) print("Found %i images" % (len(self.images),))
def __getitem__(self, index): def __getitem__(self, index):

View File

@ -115,7 +115,7 @@ class VideoClipDataset(data.Dataset):
os.makedirs(frames_out, exist_ok=False) os.makedirs(frames_out, exist_ok=False)
n = random.randint(5, 30) n = random.randint(5, 30)
self.extract_n_frames(path, frames_out, start, n) self.extract_n_frames(path, frames_out, start, n)
frames = data_util.get_image_paths('img', frames_out)[0] frames = data_util.find_files_of_type('img', frames_out)[0]
assert len(frames) == n assert len(frames) == n
img_runs.append(([self.get_image_tensor(frame) for frame in frames], frames_out)) img_runs.append(([self.get_image_tensor(frame) for frame in frames], frames_out))
start += random.randint(2,5) start += random.randint(2,5)