From 7a9c4310e88050e0d28084aac5e7c0d477c6eb45 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 23 Jun 2022 11:39:10 -0600
Subject: [PATCH] support reading cheaters directly

---
 codes/data/audio/preprocessed_mel_dataset.py |  9 ++-
 .../prep_music/generate_long_cheaters.py     | 67 +++++++++++++++++++
 2 files changed, 73 insertions(+), 3 deletions(-)
 create mode 100644 codes/scripts/audio/prep_music/generate_long_cheaters.py

diff --git a/codes/data/audio/preprocessed_mel_dataset.py b/codes/data/audio/preprocessed_mel_dataset.py
index 8667ed88..a88e704e 100644
--- a/codes/data/audio/preprocessed_mel_dataset.py
+++ b/codes/data/audio/preprocessed_mel_dataset.py
@@ -25,11 +25,14 @@ class PreprocessedMelDataset(torch.utils.data.Dataset):
             self.paths = [str(p) for p in path.rglob("*.npz")]
             torch.save(self.paths, cache_path)
         self.pad_to = opt_get(opt, ['pad_to_samples'], 10336)
+        self.squeeze = opt_get(opt, ['should_squeeze'], False)
 
     def __getitem__(self, index):
         with np.load(self.paths[index]) as npz_file:
             mel = torch.tensor(npz_file['arr_0'])
         assert mel.shape[-1] <= self.pad_to
+        if self.squeeze:
+            mel = mel.squeeze()
         padding_needed = self.pad_to - mel.shape[-1]
         mask = torch.zeros_like(mel)
         if padding_needed > 0:
@@ -52,9 +55,9 @@ class PreprocessedMelDataset(torch.utils.data.Dataset):
 if __name__ == '__main__':
     params = {
         'mode': 'preprocessed_mel',
-        'path': 'Y:\\separated\\large_mels',
-        'cache_path': 'Y:\\separated\\large_mels.pth',
-        'pad_to_samples': 10336,
+        'path': 'Y:\\separated\\large_mel_cheaters',
+        'cache_path': 'Y:\\separated\\large_mel_cheaters_win.pth',
+        'pad_to_samples': 646,
         'phase': 'train',
         'n_workers': 0,
         'batch_size': 16,
diff --git a/codes/scripts/audio/prep_music/generate_long_cheaters.py b/codes/scripts/audio/prep_music/generate_long_cheaters.py
new file mode 100644
index 00000000..2ee11f6a
--- /dev/null
+++ b/codes/scripts/audio/prep_music/generate_long_cheaters.py
@@ -0,0 +1,67 @@
+
+
+"""
+Master script that converts the preprocessed mel spectrogram .npz files found in an input directory into "cheater"
+latents using MusicCheaterLatentInjector, mirroring the input directory structure in the output path.
+""" +import argparse +import functools +import os +from multiprocessing.pool import ThreadPool +from pathlib import Path + +import numpy as np +import torch +from tqdm import tqdm + +from trainer.injectors.audio_injectors import MusicCheaterLatentInjector + + +def report_progress(progress_file, file): + with open(progress_file, 'a', encoding='utf-8') as f: + f.write(f'{file}\n') + + +cheater_inj = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {}) + + +def process_folder(file, base_path, output_path, progress_file): + outdir = os.path.join(output_path, f'{os.path.relpath(os.path.dirname(file), base_path)}') + os.makedirs(outdir, exist_ok=True) + with np.load(file) as npz_file: + mel = torch.tensor(npz_file['arr_0']).cuda().unsqueeze(0) + cheater = cheater_inj({'in': mel})['out'] + np.savez(os.path.join(outdir, os.path.basename(file)), cheater.cpu().numpy()) + report_progress(progress_file, file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, help='Path to search for files', default='Y:\\separated\\large_mels') + parser.add_argument('--progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\separated\\large_mel_cheaters\\already_processed.txt') + parser.add_argument('--output_path', type=str, help='Path for output files', default='Y:\\separated\\large_mel_cheaters') + parser.add_argument('--num_threads', type=int, help='Number of concurrent workers processing files.', default=1) + args = parser.parse_args() + + os.makedirs(args.output_path, exist_ok=True) + processed_files = set() + if os.path.exists(args.progress_file): + with open(args.progress_file, 'r', encoding='utf-8') as f: + for line in f.readlines(): + processed_files.add(line.strip()) + + cache_path = os.path.join(args.output_path, 'cache.pth') + if os.path.exists(cache_path): + root_music_files = torch.load(cache_path) + else: + path = Path(args.path) + root_music_files = set(path.rglob("*.npz")) + torch.save(root_music_files, cache_path) + + orig_len = len(root_music_files) + folders = root_music_files - processed_files + print(f"Found {len(folders)} files to process. Total processing is {100 * (orig_len - len(folders)) / orig_len}% complete.") + + with ThreadPool(args.num_threads) as pool: + list(tqdm(pool.imap(functools.partial(process_folder, output_path=args.output_path, base_path=args.path, + progress_file=args.progress_file), folders), total=len(folders)))