support reading cheaters directly
This commit is contained in:
parent
b210e5025c
commit
7a9c4310e8
|
@ -25,11 +25,14 @@ class PreprocessedMelDataset(torch.utils.data.Dataset):
|
|||
self.paths = [str(p) for p in path.rglob("*.npz")]
|
||||
torch.save(self.paths, cache_path)
|
||||
self.pad_to = opt_get(opt, ['pad_to_samples'], 10336)
|
||||
self.squeeze = opt_get(opt, ['should_squeeze'], False)
|
||||
|
||||
def __getitem__(self, index):
|
||||
with np.load(self.paths[index]) as npz_file:
|
||||
mel = torch.tensor(npz_file['arr_0'])
|
||||
assert mel.shape[-1] <= self.pad_to
|
||||
if self.squeeze:
|
||||
mel = mel.squeeze()
|
||||
padding_needed = self.pad_to - mel.shape[-1]
|
||||
mask = torch.zeros_like(mel)
|
||||
if padding_needed > 0:
|
||||
|
@ -52,9 +55,9 @@ class PreprocessedMelDataset(torch.utils.data.Dataset):
|
|||
if __name__ == '__main__':
|
||||
params = {
|
||||
'mode': 'preprocessed_mel',
|
||||
'path': 'Y:\\separated\\large_mels',
|
||||
'cache_path': 'Y:\\separated\\large_mels.pth',
|
||||
'pad_to_samples': 10336,
|
||||
'path': 'Y:\\separated\\large_mel_cheaters',
|
||||
'cache_path': 'Y:\\separated\\large_mel_cheaters_win.pth',
|
||||
'pad_to_samples': 646,
|
||||
'phase': 'train',
|
||||
'n_workers': 0,
|
||||
'batch_size': 16,
|
||||
|
|
67
codes/scripts/audio/prep_music/generate_long_cheaters.py
Normal file
67
codes/scripts/audio/prep_music/generate_long_cheaters.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
|
||||
|
||||
"""
|
||||
Master script that processes all MP3 files found in an input directory. Splits those files up into sub-files of a
|
||||
predetermined duration.
|
||||
"""
|
||||
import argparse
|
||||
import functools
|
||||
import os
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
|
||||
|
||||
|
||||
def report_progress(progress_file, file):
|
||||
with open(progress_file, 'a', encoding='utf-8') as f:
|
||||
f.write(f'{file}\n')
|
||||
|
||||
|
||||
cheater_inj = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {})
|
||||
|
||||
|
||||
def process_folder(file, base_path, output_path, progress_file):
|
||||
outdir = os.path.join(output_path, f'{os.path.relpath(os.path.dirname(file), base_path)}')
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
with np.load(file) as npz_file:
|
||||
mel = torch.tensor(npz_file['arr_0']).cuda().unsqueeze(0)
|
||||
cheater = cheater_inj({'in': mel})['out']
|
||||
np.savez(os.path.join(outdir, os.path.basename(file)), cheater.cpu().numpy())
|
||||
report_progress(progress_file, file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--path', type=str, help='Path to search for files', default='Y:\\separated\\large_mels')
|
||||
parser.add_argument('--progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\separated\\large_mel_cheaters\\already_processed.txt')
|
||||
parser.add_argument('--output_path', type=str, help='Path for output files', default='Y:\\separated\\large_mel_cheaters')
|
||||
parser.add_argument('--num_threads', type=int, help='Number of concurrent workers processing files.', default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
processed_files = set()
|
||||
if os.path.exists(args.progress_file):
|
||||
with open(args.progress_file, 'r', encoding='utf-8') as f:
|
||||
for line in f.readlines():
|
||||
processed_files.add(line.strip())
|
||||
|
||||
cache_path = os.path.join(args.output_path, 'cache.pth')
|
||||
if os.path.exists(cache_path):
|
||||
root_music_files = torch.load(cache_path)
|
||||
else:
|
||||
path = Path(args.path)
|
||||
root_music_files = set(path.rglob("*.npz"))
|
||||
torch.save(root_music_files, cache_path)
|
||||
|
||||
orig_len = len(root_music_files)
|
||||
folders = root_music_files - processed_files
|
||||
print(f"Found {len(folders)} files to process. Total processing is {100 * (orig_len - len(folders)) / orig_len}% complete.")
|
||||
|
||||
with ThreadPool(args.num_threads) as pool:
|
||||
list(tqdm(pool.imap(functools.partial(process_folder, output_path=args.output_path, base_path=args.path,
|
||||
progress_file=args.progress_file), folders), total=len(folders)))
|
Loading…
Reference in New Issue
Block a user