From 4100469902a5442f61c346b349a6876c346586d7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 8 Aug 2021 23:23:13 -0600 Subject: [PATCH] Add a tool to split mp3 files into arbitrary chunks of wav files --- codes/data/audio/random_mp3_splitter.py | 41 +++++++++++++++ codes/data/audio/wavfile_dataset.py | 4 +- codes/data/chunk_with_reference.py | 2 +- codes/data/full_image_dataset.py | 6 +-- codes/data/image_folder_dataset.py | 2 +- codes/data/multiscale_dataset.py | 2 +- codes/data/util.py | 54 +++++++++++--------- codes/scripts/audio/test_audio_similarity.py | 4 +- codes/scripts/extract_square_images.py | 2 +- codes/scripts/extract_temporal_squares.py | 2 +- 10 files changed, 82 insertions(+), 37 deletions(-) create mode 100644 codes/data/audio/random_mp3_splitter.py diff --git a/codes/data/audio/random_mp3_splitter.py b/codes/data/audio/random_mp3_splitter.py new file mode 100644 index 00000000..4170f0a1 --- /dev/null +++ b/codes/data/audio/random_mp3_splitter.py @@ -0,0 +1,41 @@ +import audio2numpy +from scipy.io import wavfile +from tqdm import tqdm + +from data.util import find_audio_files +import numpy as np +import torch +import torch.nn.functional as F +import os.path as osp + +if __name__ == '__main__': + src_dir = 'E:\\audio\\books' + clip_length = 5 # In seconds + sparsity = .05 # Only this proportion of the total clips are extracted as wavs. + output_sample_rate=22050 + output_dir = 'E:\\audio\\books-clips' + + files = find_audio_files(src_dir, include_nonwav=True) + for e, file in enumerate(tqdm(files)): + if e < 7250: + continue + file_basis = osp.relpath(file, src_dir).replace('/', '_').replace('\\', '_') + try: + wave, sample_rate = audio2numpy.open_audio(file) + except: + print(f"Error with {file}") + continue + wave = torch.tensor(wave) + # Strip out channels. + if len(wave.shape) > 1: + wave = wave[0] # Just use the first channel. + + # Calculate how much data we need to extract for each clip. + clip_sz = sample_rate * clip_length + interval = int(sample_rate * (clip_length / sparsity)) + i = 0 + while (i+clip_sz) < wave.shape[-1]: + clip = wave[i:i+clip_sz] + clip = F.interpolate(clip.view(1,1,clip_sz), scale_factor=output_sample_rate/sample_rate).squeeze() + wavfile.write(osp.join(output_dir, f'{file_basis}_{i}.wav'), output_sample_rate, clip.numpy()) + i = i + interval diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py index 06c2913e..d8c73415 100644 --- a/codes/data/audio/wavfile_dataset.py +++ b/codes/data/audio/wavfile_dataset.py @@ -7,7 +7,7 @@ import torchaudio from tqdm import tqdm from data.audio.wav_aug import WavAugmentor -from data.util import get_image_paths, is_wav_file +from data.util import find_files_of_type, is_wav_file from models.tacotron2.taco_utils import load_wav_to_torch from utils.util import opt_get @@ -21,7 +21,7 @@ class WavfileDataset(torch.utils.data.Dataset): self.audiopaths = torch.load(cache_path) else: print("Building cache..") - self.audiopaths = get_image_paths('img', opt['path'], qualifier=is_wav_file)[0] + self.audiopaths = find_files_of_type('img', opt['path'], qualifier=is_wav_file)[0] torch.save(self.audiopaths, cache_path) # Parse options diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py index b9acebeb..c7bce448 100644 --- a/codes/data/chunk_with_reference.py +++ b/codes/data/chunk_with_reference.py @@ -10,7 +10,7 @@ from utils.util import opt_get class ChunkWithReference: def __init__(self, opt, path): self.path = path.path - self.tiles, _ = util.get_image_paths('img', self.path) + self.tiles, _ = util.find_files_of_type('img', self.path) self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False) self.need_ref = opt_get(opt, ['need_ref'], False) if 'ignore_first' in opt.keys(): diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py index fff6cf3f..92e14036 100644 --- a/codes/data/full_image_dataset.py +++ b/codes/data/full_image_dataset.py @@ -29,17 +29,17 @@ class FullImageDataset(data.Dataset): self.LQ_env, self.GT_env = None, None self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1 - self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights']) + self.paths_GT, self.sizes_GT = util.find_files_of_type(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights']) if 'dataroot_LQ' in opt.keys(): self.paths_LQ = [] if isinstance(opt['dataroot_LQ'], list): # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and # we want the model to learn them all. for dr_lq in opt['dataroot_LQ']: - lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq) + lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, dr_lq) self.paths_LQ.append(lq_path) else: - lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) + lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, opt['dataroot_LQ']) self.paths_LQ.append(lq_path) assert self.paths_GT, 'Error: GT path is empty.' diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 63db243e..712d04dd 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -85,7 +85,7 @@ class ImageFolderDataset: imgs = torch.load(cache_path) else: print("Building image folder cache, this can take some time for large datasets..") - imgs = util.get_image_paths('img', path)[0] + imgs = util.find_files_of_type('img', path)[0] torch.save(imgs, cache_path) for w in range(weight): self.image_paths.extend(imgs) diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py index 6673f6e0..2314e3e6 100644 --- a/codes/data/multiscale_dataset.py +++ b/codes/data/multiscale_dataset.py @@ -40,7 +40,7 @@ class MultiScaleDataset(data.Dataset): self.num_scales = self.opt['num_scales'] self.hq_size_cap = self.tile_size * 2 ** self.num_scales self.scale = self.opt['scale'] - self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']]) + self.paths_hq, self.sizes_hq = util.find_files_of_type(self.data_type, opt['paths'], [1 for _ in opt['paths']]) self.corruptor = ImageCorruptor(opt) diff --git a/codes/data/util.py b/codes/data/util.py index 23b072bb..7c4c9e11 100644 --- a/codes/data/util.py +++ b/codes/data/util.py @@ -39,10 +39,16 @@ def cv2torch(cv, batchify=True): def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + def is_wav_file(filename): return filename.endswith('.wav') +def is_audio_file(filename): + AUDIO_EXTENSIONS = ['.wav', '.mp3', '.wma', 'm4b'] + return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS) + + def _get_paths_from_images(path, qualifier=is_image_file): """get image path list from image folder""" assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) @@ -67,32 +73,30 @@ def _get_paths_from_lmdb(dataroot): return paths, sizes -def get_image_paths(data_type, dataroot, weights=[], qualifier=is_image_file): - """get image path list - support lmdb or image files""" - paths, sizes = None, None - if dataroot is not None: - if data_type == 'lmdb': - paths, sizes = _get_paths_from_lmdb(dataroot) - elif data_type == 'img': - if isinstance(dataroot, list): - paths = [] - for i in range(len(dataroot)): - r = dataroot[i] - extends = 1 +def find_audio_files(dataroot, include_nonwav=False): + if include_nonwav: + return find_files_of_type(None, dataroot, qualifier=is_audio_file)[0] + else: + return find_files_of_type(None, dataroot, qualifier=is_wav_file)[0] - # Weights have the effect of repeatedly adding the paths from the given root to the final product. - if weights: - extends = weights[i] - for j in range(extends): - paths.extend(_get_paths_from_images(r, qualifier)) - paths = sorted(paths) - sizes = len(paths) - else: - paths = sorted(_get_paths_from_images(dataroot, qualifier)) - sizes = len(paths) - else: - raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) + +def find_files_of_type(data_type, dataroot, weights=[], qualifier=is_image_file): + if isinstance(dataroot, list): + paths = [] + for i in range(len(dataroot)): + r = dataroot[i] + extends = 1 + + # Weights have the effect of repeatedly adding the paths from the given root to the final product. + if weights: + extends = weights[i] + for j in range(extends): + paths.extend(_get_paths_from_images(r, qualifier)) + paths = sorted(paths) + sizes = len(paths) + else: + paths = sorted(_get_paths_from_images(dataroot, qualifier)) + sizes = len(paths) return paths, sizes diff --git a/codes/scripts/audio/test_audio_similarity.py b/codes/scripts/audio/test_audio_similarity.py index a7afc5bf..297fb4e4 100644 --- a/codes/scripts/audio/test_audio_similarity.py +++ b/codes/scripts/audio/test_audio_similarity.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from data.util import is_wav_file, get_image_paths +from data.util import is_wav_file, find_files_of_type from models.audio_resnet import resnet34, resnet50 from models.tacotron2.taco_utils import load_wav_to_torch from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict @@ -12,7 +12,7 @@ from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_stat if __name__ == '__main__': window = 48000 root_path = 'D:\\tmp\\clips' - paths = get_image_paths('img', root_path, qualifier=is_wav_file)[0] + paths = find_files_of_type('img', root_path, qualifier=is_wav_file)[0] clips = [] for path in paths: clip, sr = load_wav_to_torch(os.path.join(root_path, path)) diff --git a/codes/scripts/extract_square_images.py b/codes/scripts/extract_square_images.py index a8f9eff8..6b792e61 100644 --- a/codes/scripts/extract_square_images.py +++ b/codes/scripts/extract_square_images.py @@ -39,7 +39,7 @@ class TiledDataset(data.Dataset): def __init__(self, opt): self.opt = opt input_folder = opt['input_folder'] - self.images = data_util.get_image_paths('img', input_folder)[0] + self.images = data_util.find_files_of_type('img', input_folder)[0] print("Found %i images" % (len(self.images),)) def __getitem__(self, index): diff --git a/codes/scripts/extract_temporal_squares.py b/codes/scripts/extract_temporal_squares.py index bb74f7ed..b8e167a9 100644 --- a/codes/scripts/extract_temporal_squares.py +++ b/codes/scripts/extract_temporal_squares.py @@ -115,7 +115,7 @@ class VideoClipDataset(data.Dataset): os.makedirs(frames_out, exist_ok=False) n = random.randint(5, 30) self.extract_n_frames(path, frames_out, start, n) - frames = data_util.get_image_paths('img', frames_out)[0] + frames = data_util.find_files_of_type('img', frames_out)[0] assert len(frames) == n img_runs.append(([self.get_image_tensor(frame) for frame in frames], frames_out)) start += random.randint(2,5)