From 368dca18b14fec8a33882b5992da4c787888313e Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 19 Jun 2022 15:07:24 -0600 Subject: [PATCH 1/4] mdf fixes + support for tfd-based waveform gen --- codes/trainer/eval/music_diffusion_fid.py | 43 ++++++++--------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index b41f7572..066b2e28 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -79,24 +79,18 @@ class MusicDiffusionFid(evaluator.Evaluator): return list(glob(f'{path}/*.wav')) def perform_diffusion_spec_decode(self, audio, sample_rate=22050): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) - output_shape = (1, 16, audio.shape[-1] // 16) + output_shape = (1, 256, audio.shape[-1] // 256) mel = self.spec_fn({'in': audio})['out'] gen = self.diffuser.p_sample_loop(self.model, output_shape, - model_kwargs={'aligned_conditioning': mel}) - gen = pixel_shuffle_1d(gen, 16) + model_kwargs={'codes': mel}) + gen = pixel_shuffle_1d(gen, 256) return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate def perform_diffusion_from_codes(self, audio, sample_rate=22050): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] @@ -116,10 +110,7 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] @@ -148,10 +139,7 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] @@ -174,10 +162,7 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050): - if sample_rate != sample_rate: - real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) - else: - real_resampled = audio + real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] @@ -273,17 +258,17 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': - diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator', + diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41500_generator_ema.pth' + 
load_path='X:\\dlas\\experiments\\train_music_waveform_gen_retry\\models\\22000_generator_ema.pth' ).cuda() opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. - 'diffusion_steps': 200, - 'conditioning_free': True, 'conditioning_free_k': 2, - 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant', + 'diffusion_steps': 100, + 'conditioning_free': False, 'conditioning_free_k': 1, + 'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode', #'partial_low': 128, 'partial_high': 192 } - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 605, 'device': 'cuda', 'opt': {}} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 100, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval()) From 8c8efbe1319c8f438300bffea9c9a5a00518279a Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 19 Jun 2022 17:54:08 -0600 Subject: [PATCH 2/4] fix code_emb --- .../audio/music/transformer_diffusion12.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index 58a11870..e1b748ca 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -210,8 +210,8 @@ class TransformerDiffusion(nn.Module): def timestep_independent(self, prior, expected_seq_len): if self.new_code_expansion: - code_emb = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1) - code_emb = self.ar_input(code_emb) if self.ar_prior else self.input_converter(code_emb) + prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1) + code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. 
@@ -732,14 +732,14 @@ def test_cheater_model():
     # For music:
     model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
-                                      model_channels=1024, contraction_dim=512,
-                                      prenet_channels=1024, num_heads=8,
-                                      input_vec_dim=256, num_layers=12, prenet_layers=6,
+                                      model_channels=1536, contraction_dim=768,
+                                      prenet_channels=1024, num_heads=12,
+                                      input_vec_dim=256, num_layers=20, prenet_layers=6,
                                       dropout=.1, new_code_expansion=True,
                                       )
-    diff_weights = torch.load('extracted_diff.pth')
-    model.diff.load_state_dict(diff_weights, strict=False)
-    cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\19500_generator_ema.pth')
+    #diff_weights = torch.load('extracted_diff.pth')
+    #model.diff.load_state_dict(diff_weights, strict=False)
+    cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\60000_generator_ema.pth')
     cheater_ar = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
                                vqargs= {'positional_dims': 1, 'channels': 64,
                                         'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,

From 90b232f9656feb070aa572c3ce69f16e6d21abce Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 19 Jun 2022 17:54:37 -0600
Subject: [PATCH 3/4] gen_long_mels

---
 .../audio/prep_music/generate_long_mels.py    | 101 +++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 codes/scripts/audio/prep_music/generate_long_mels.py

diff --git a/codes/scripts/audio/prep_music/generate_long_mels.py b/codes/scripts/audio/prep_music/generate_long_mels.py
new file mode 100644
index 00000000..00a74075
--- /dev/null
+++ b/codes/scripts/audio/prep_music/generate_long_mels.py
@@ -0,0 +1,101 @@
+
+
+"""
+Master script that processes directories of pre-split, vocal-separated music clips ('no_vocals.wav' files produced by
+an earlier splitting/separation pass). Consecutive clips from each source track are concatenated up to a maximum
+duration, and a long mel spectrogram is saved for each combined segment.
+"""
+import argparse
+import functools
+import os
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+
+import torch
+import torchaudio
+import numpy as np
+from tqdm import tqdm
+
+from data.util import find_audio_files, find_files_of_type
+from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
+from utils.util import load_audio
+
+
+def report_progress(progress_file, file):
+    with open(progress_file, 'a', encoding='utf-8') as f:
+        f.write(f'{file}\n')
+
+
+spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
+                                       'true_normalization': True, 'normalize': True, 'in': 'in', 'out': 'out'}, {}).cuda()
+
+def produce_mel(audio):
+    return spec_fn({'in': audio.unsqueeze(0)})['out'].squeeze(0)
+
+
+def process_folder(folder, base_path, output_path, progress_file, max_duration, sampling_rate=22050):
+    outdir = os.path.join(output_path, f'{os.path.relpath(folder, base_path)}')
+    os.makedirs(outdir, exist_ok=True)
+
+    # Subclip folders are named by integer ordinal; sort numerically so the contiguity
+    # check below sees them in playback order.
+    files = sorted((f for f in os.listdir(folder) if f.isdigit()), key=int)
+    i = 0
+    output_i = 0
+    while i < len(files):
+        last_ordinal = -1
+        total_progress = 0
+        to_combine = []
+        while i < len(files) and total_progress < max_duration:
+            audio_file = os.path.join(folder, files[i], "no_vocals.wav")
+            if not os.path.exists(audio_file):
+                i += 1  # Skip entries missing a separated track, else the outer loop never advances.
+                break
+            file_ordinal = int(files[i])
+            if last_ordinal != -1 and file_ordinal != last_ordinal+1:
+                break  # Non-contiguous subclip; flush the current segment and start a new one here.
+            to_combine.append(load_audio(audio_file, sampling_rate))
+            last_ordinal = file_ordinal
+            i += 1
+            total_progress += 30  # Each separated subclip is assumed to be 30 seconds long.
+        if total_progress > 30:  # Only save segments longer than a single subclip.
+            combined = torch.cat(to_combine, dim=-1).cuda()
+            mel = produce_mel(combined).cpu().numpy()
+            np.savez(os.path.join(outdir, f'{output_i}'), mel)
+            output_i += 1
+    report_progress(progress_file, folder)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path', type=str, help='Path to search for files', default='Y:\\separated')
+    parser.add_argument('--progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\separated\\large_mels\\already_processed.txt')
+    parser.add_argument('--output_path', type=str, help='Path for output files', default='Y:\\separated\\large_mels')
+    parser.add_argument('--num_threads', type=int, help='Number of concurrent workers processing files.', default=3)
+    parser.add_argument('--max_duration', type=int, help='Duration per clip in seconds', default=120)
+    args = parser.parse_args()
+
+    os.makedirs(args.output_path, exist_ok=True)
+    processed_files = set()
+    if os.path.exists(args.progress_file):
+        with open(args.progress_file, 'r', encoding='utf-8') as f:
+            for line in f.readlines():
+                processed_files.add(line.strip())
+
+    cache_path = os.path.join(args.output_path, 'cache.pth')
+    if os.path.exists(cache_path):
+        root_music_files = torch.load(cache_path)
+    else:
+        path = Path(args.path)
+        def collect(p):
+            return str(os.path.dirname(os.path.dirname(p)))
+        root_music_files = {collect(p) for p in path.rglob("*no_vocals.wav")}
+        torch.save(root_music_files, cache_path)
+
+    orig_len = len(root_music_files)
+    folders = root_music_files - processed_files
+    print(f"Found {len(folders)} folders to process. 
Total processing is {100 * (orig_len - len(folders)) / max(orig_len, 1):.1f}% complete.")
+
+    with ThreadPool(args.num_threads) as pool:
+        list(tqdm(pool.imap(functools.partial(process_folder, output_path=args.output_path, base_path=args.path,
+                                              progress_file=args.progress_file, max_duration=args.max_duration), folders), total=len(folders)))

From f425afc965253fb6096507f7c98a8f2bd50c9267 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 19 Jun 2022 18:00:30 -0600
Subject: [PATCH 4/4] permute codes

---
 .../models/audio/music/transformer_diffusion12.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index e1b748ca..208b4c6f 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -100,6 +100,7 @@ class TransformerDiffusion(nn.Module):
             use_fp16=False,
             ar_prior=False,
             new_code_expansion=False,
+            permute_codes=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
             # Parameters for re-training head
@@ -116,6 +117,7 @@ class TransformerDiffusion(nn.Module):
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
         self.new_code_expansion = new_code_expansion
+        self.permute_codes = permute_codes
 
         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
 
@@ -228,6 +230,8 @@ class TransformerDiffusion(nn.Module):
     def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
         if precomputed_code_embeddings is not None:
             assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
+        if self.permute_codes:
+            codes = codes.permute(0,2,1)
 
         unused_params = []
         if conditioning_free:
@@ -605,6 +609,15 @@ def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
     return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
 
 
+def test_tfd():
+    clip = torch.randn(2,256,400)
+    ts = torch.LongTensor([600, 600])
+    model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
+                                 prenet_channels=1024, num_heads=3, permute_codes=True,
+                                 input_vec_dim=256, num_layers=12, prenet_layers=4,
+                                 dropout=.1)
+    model(clip, ts, clip)
+
 def test_quant_model():
     clip = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
@@ -767,4 +780,4 @@ def extract_diff(in_f, out_f, remove_head=False):
 
 if __name__ == '__main__':
     #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
-    test_cheater_model()
+    test_tfd()
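
A minimal sketch of exercising the permute_codes path added in PATCH 4/4 (not part of the patches above; it mirrors test_tfd(), assumes the repo's codes/ directory is on the import path, and the head count here is illustrative rather than taken from any training config):

    # With permute_codes=True, codes can be fed channels-first as (batch, input_vec_dim, seq);
    # forward() permutes them internally to the (batch, seq, input_vec_dim) layout the prior expects.
    import torch
    from models.audio.music.transformer_diffusion12 import TransformerDiffusion

    model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
                                 prenet_channels=1024, num_heads=4, permute_codes=True,
                                 input_vec_dim=256, num_layers=12, prenet_layers=4, dropout=.1)
    x = torch.randn(2, 256, 400)         # noised input: (batch, in_channels, seq)
    ts = torch.LongTensor([600, 600])    # diffusion timesteps, one per batch element
    out = model(x, ts, codes=x)          # codes passed as (2, 256, 400); permuted to (2, 400, 256)
    print(out.shape)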