From 15decfdb98c06e80b7e6702c654e26376f74b2e6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 20 Jul 2022 10:19:02 -0600 Subject: [PATCH] misc --- .../audio/music/transformer_diffusion13.py | 27 +++++------ .../audio/music/transformer_diffusion14.py | 2 +- codes/train.py | 2 +- codes/trainer/eval/music_diffusion_fid.py | 47 +++++++++---------- .../injectors/gaussian_diffusion_injector.py | 4 +- 5 files changed, 39 insertions(+), 43 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 1a7277bd..3fdcdce9 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -4,6 +4,7 @@ from random import randrange import torch import torch.nn as nn import torch.nn.functional as F +import torchvision.utils from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear @@ -38,7 +39,7 @@ class SubBlock(nn.Module): self.mask_initialized = True blk_enc = self.blk_emb_proj(blk_emb) ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask)) - ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x. + ah = ah[:,:,blk_enc.shape[-1]:] # Strip off the blk_emc used for attention and re-align with x. ah = F.gelu(self.attnorm(ah)) h = torch.cat([ah, x], dim=1) hf = self.dropout(checkpoint(self.ff, h)) @@ -168,25 +169,21 @@ class TransformerDiffusion(nn.Module): } return groups - def input_to_random_resolution_and_window(self, x, x_prior): + def input_to_random_resolution_and_window(self, x): """ This function MUST be applied to the target *before* noising. It returns the reduced, re-scoped target as well - as caches an internal prior for the rescoped target which will be useud in training. + as caches an internal prior for the rescoped target which will be used in training. Args: x: Diffusion target - x_prior: Prior input, which is generally just {x} """ - assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}' - resolution = randrange(1, self.resolution_steps) + resolution = randrange(0, self.resolution_steps) resolution_scale = 2 ** resolution s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True) - s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True) s_diff = s.shape[-1] - self.max_window if s_diff > 1: start = randrange(0, s_diff) s = s[:,:,start:start+self.max_window] - s_prior = x_prior[:,:,start:start+self.max_window] - s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True) + s_prior = F.interpolate(s, scale_factor=.25, mode='linear', align_corners=True) s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True) self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)) return s @@ -196,16 +193,18 @@ class TransformerDiffusion(nn.Module): h = x if resolution is None: + # This is assumed to be training. assert self.preprocessed is not None, 'Preprocessing function not called.' - h = x + assert x_prior is None, 'Provided prior will not be used, instead preprocessing output will be used.' h_sub, resolution = self.preprocessed self.preprocessed = None else: - h_sub = F.interpolate(x_prior, scale_factor=4, mode='linear', align_corners=True) - assert h.shape == h_sub.shape, f'{h.shape} {h_sub.shape}' + assert h.shape[-1] > x_prior.shape[-1] * 3.9, f'{h.shape} {x_prior.shape}' + h_sub = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True) if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) + h_sub = self.unconditioned_prior.repeat(x.shape[0], 1, x.shape[-1]) + code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) else: MIN_COND_LEN = 200 MAX_COND_LEN = 1200 @@ -227,8 +226,8 @@ class TransformerDiffusion(nn.Module): time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) res_emb = self.resolution_embed(resolution) blk_emb = torch.cat([time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1) - h = torch.cat([h, h_sub], dim=1) + h = torch.cat([h, h_sub], dim=1) h = self.inp_block(h) for layer in self.layers: h = checkpoint(layer, h, blk_emb) diff --git a/codes/models/audio/music/transformer_diffusion14.py b/codes/models/audio/music/transformer_diffusion14.py index 0d163f8a..f1b98180 100644 --- a/codes/models/audio/music/transformer_diffusion14.py +++ b/codes/models/audio/music/transformer_diffusion14.py @@ -24,7 +24,7 @@ class SubBlock(nn.Module): self.ffnorm = nn.GroupNorm(8, contraction_dim) if self.enable_attention_masking: # All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region. - self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1) + self.mask = build_local_attention_mask(n=4000, l=48, fixed_region=1) self.mask_initialized = False else: self.mask = None diff --git a/codes/train.py b/codes/train.py index 9376a56a..6c86a7c6 100644 --- a/codes/train.py +++ b/codes/train.py @@ -340,7 +340,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_instrument_quantizer.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_multilevel_sr.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index bf084192..50c03dc1 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -1,28 +1,24 @@ -import functools +import os import os import os.path as osp from glob import glob -from random import shuffle -from time import time import numpy as np import torch +import torch.nn.functional as F import torchaudio import torchvision from pytorch_fid.fid_score import calculate_frechet_distance from torch import distributed from tqdm import tqdm -import torch.nn.functional as F import trainer.eval.evaluator as evaluator from data.audio.unsupervised_audio_dataset import load_audio -from models.audio.mel2vec import ContrastiveTrainingWrapper -from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen from models.clip.contrastive_audio import ContrastiveAudio from models.diffusion.gaussian_diffusion import get_named_beta_schedule from models.diffusion.respace import space_timesteps, SpacedDiffusion from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \ - normalize_mel, KmeansQuantizerInjector + KmeansQuantizerInjector, normalize_torch_mel from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \ get_mel2wav_v3_model, get_ar_prior from utils.util import opt_get, load_model_from_config @@ -117,7 +113,7 @@ class MusicDiffusionFid(evaluator.Evaluator): model_kwargs={'codes': mel}) gen = pixel_shuffle_1d(gen, self.squeeze_ratio) - return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate + return gen, real_resampled, normalize_torch_mel(self.spec_fn({'in': gen})['out']), normalize_torch_mel(mel), sample_rate def perform_diffusion_from_codes(self, audio, sample_rate=22050): real_resampled = audio @@ -126,7 +122,7 @@ class MusicDiffusionFid(evaluator.Evaluator): mel = self.spec_fn({'in': audio})['out'] codegen = self.local_modules['codegen'].to(mel.device) codes = codegen.get_codes(mel, project=True) - mel_norm = normalize_mel(mel) + mel_norm = normalize_torch_mel(mel) gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, model_kwargs={'codes': codes, 'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])}) @@ -140,27 +136,27 @@ class MusicDiffusionFid(evaluator.Evaluator): return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050): - real_resampled = audio audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] - mel_norm = normalize_mel(mel) + mel_norm = normalize_torch_mel(mel) #def denoising_fn(x): # q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1) # s = q9.clamp(1, 9999999999) # x = x.clamp(-s, s) / s # return x - gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, #denoised_fn=denoising_fn, clip_denoised=False, - model_kwargs={'truth_mel': mel_norm}) + sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop + gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm}) gen_mel_denorm = denormalize_torch_mel(gen_mel) output_shape = (1,16,audio.shape[-1]//16) self.spec_decoder = self.spec_decoder.to(audio.device) - gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, + sampler = self.spectral_diffuser.ddim_sample_loop if self.ddim else self.spectral_diffuser.p_sample_loop + gen_wav = sampler(self.spec_decoder, output_shape, model_kwargs={'aligned_conditioning': gen_mel_denorm}) gen_wav = pixel_shuffle_1d(gen_wav, 16) - real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape, + real_wav = sampler(self.spec_decoder, output_shape, model_kwargs={'aligned_conditioning': mel}) real_wav = pixel_shuffle_1d(real_wav, 16) @@ -170,7 +166,7 @@ class MusicDiffusionFid(evaluator.Evaluator): audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] - mel_norm = normalize_mel(mel) + mel_norm = normalize_torch_mel(mel) cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm) # 1. Generate the cheater latent using the input as a reference. @@ -203,7 +199,7 @@ class MusicDiffusionFid(evaluator.Evaluator): audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] - mel_norm = normalize_mel(mel) + mel_norm = normalize_torch_mel(mel) cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm) cheater_codes = self.kmeans_inj({'in': cheater})['out'] ar_latent = self.local_modules['ar_prior'].to(audio.device)(cheater_codes, cheater, return_latent=True) @@ -233,16 +229,17 @@ class MusicDiffusionFid(evaluator.Evaluator): def perform_chained_sr(self, audio, sample_rate=22050): audio = audio.unsqueeze(0) mel = self.spec_fn({'in': audio})['out'] - mel_norm = normalize_mel(mel) + mel_norm = normalize_torch_mel(mel) + #mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window. conditioning = mel_norm[:,:,:1200] downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True) - sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop stage1_shape = (1, 256, downsampled.shape[-1]*4) + sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop # Chain super-sampling using 2 stages. - stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([2], device=audio.device), + stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device), 'x_prior': downsampled, 'conditioning_input': conditioning}) - stage2 = sampler(self.model, audio.shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device), + stage2 = sampler(self.model, mel.shape, model_kwargs={'resolution': torch.tensor([0], device=audio.device), 'x_prior': stage1, 'conditioning_input': conditioning}) # Decode into waveform. @@ -328,17 +325,17 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator', also_load_savepoint=False, strict_load=False, - load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator_fixed.pth' + load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\22000_generator.pth' ).cuda() opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. - 'diffusion_steps': 64, # basis: 192 - 'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': False, + 'diffusion_steps': 128, # basis: 192 + 'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False, 'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr', #'causal': True, 'causal_slope': 4, #'partial_low': 128, 'partial_high': 192 } - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 10, 'device': 'cuda', 'opt': {}} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) fds = [] for i in range(2): diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py index b2e9222a..ca364d90 100644 --- a/codes/trainer/injectors/gaussian_diffusion_injector.py +++ b/codes/trainer/injectors/gaussian_diffusion_injector.py @@ -80,7 +80,7 @@ class GaussianDiffusionInjector(Injector): def forward(self, state): gen = self.env['generators'][self.opt['generator']] hq = state[self.input] - assert hq.max() < 1.000001 or hq.min() > -1.00001, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}" + assert hq.max() < 1.000001 or hq.min() > -1.00001, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}" with autocast(enabled=self.env['opt']['fp16']): if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0): @@ -90,7 +90,7 @@ class GaussianDiffusionInjector(Injector): self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically. model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()} if self.preprocess_fn is not None: - hq = getattr(gen.module, self.preprocess_fn)(hq, **model_inputs) + hq = getattr(gen.module, self.preprocess_fn)(hq) t, weights = sampler.sample(hq.shape[0], hq.device) if self.causal_mode: