From b2a83efe50bedead865fd4068410a44d20ac3add Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 1 Jun 2022 16:35:15 -0600
Subject: [PATCH] a few fixes

---
 .../audio/music/transformer_diffusion5.py | 11 +++++++---
 .../audio/music/transformer_diffusion7.py |  8 +++++---
 codes/trainer/eval/music_diffusion_fid.py | 20 ++++++++++++-------
 3 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion5.py b/codes/models/audio/music/transformer_diffusion5.py
index e9ef9a23..1fc4053b 100644
--- a/codes/models/audio/music/transformer_diffusion5.py
+++ b/codes/models/audio/music/transformer_diffusion5.py
@@ -227,12 +227,17 @@ class TransformerDiffusionWithQuantizer(nn.Module):
             self.m2v.min_gumbel_temperature,
         )
 
-    def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
         proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1)
         proj = self.m2v.m2v.projector.layer_norm(proj)
-        vectors, _, probs = self.m2v.quantizer(proj, return_probs=True)
+        vectors, perplexity, probs = self.m2v.quantizer(proj, return_probs=True)
+        diversity = (self.m2v.quantizer.num_codevectors - perplexity) / self.m2v.quantizer.num_codevectors
         self.log_codes(probs)
-        return self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        diff = self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        if disable_diversity:
+            return diff
+        else:
+            return diff, diversity
 
     def log_codes(self, codes):
         if self.internal_step % 5 == 0:
diff --git a/codes/models/audio/music/transformer_diffusion7.py b/codes/models/audio/music/transformer_diffusion7.py
index f8102690..865ea8c2 100644
--- a/codes/models/audio/music/transformer_diffusion7.py
+++ b/codes/models/audio/music/transformer_diffusion7.py
@@ -217,7 +217,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
             self.m2v.min_gumbel_temperature,
         )
 
-    def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
             proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
@@ -231,8 +231,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
             proj = proj + unused
             diversity_loss = diversity_loss * 0
 
-        return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
-                         conditioning_free=conditioning_free), diversity_loss
+        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        if disable_diversity:
+            return diff
+        return diff, diversity_loss
 
     def get_debug_values(self, step, __):
         if self.m2v.total_codes > 0:
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 5518cacf..7d9c8dd7 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -114,9 +114,15 @@ class MusicDiffusionFid(evaluator.Evaluator):
         mel = self.spec_fn({'in': audio})['out']
         mel_norm = normalize_mel(mel)
 
-        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
+        def denoising_fn(x):
+            q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
+            s = q9.clamp(1, 9999999999)
+            x = x.clamp(-s, s) / s
+            return x
+        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, denoised_fn=denoising_fn, clip_denoised=False,
                                               model_kwargs={'truth_mel': mel,
-                                                            'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])})
+                                                            'conditioning_input': torch.zeros_like(mel_norm[:,:,:390]),
+                                                            'disable_diversity': True})
         gen_mel_denorm = denormalize_mel(gen_mel)
 
         output_shape = (1,16,audio.shape[-1]//16)
@@ -163,7 +169,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         gen_projections = []
         real_projections = []
         for i in tqdm(list(range(0, len(self.data), self.skip))):
-            path = self.data[i + self.env['rank']]
+            path = self.data[(i + self.env['rank']) % len(self.data)]
             audio = load_audio(path, 22050).to(self.dev)
             audio = audio[:, :22050*10]
             sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
@@ -195,13 +201,13 @@
 
 
 if __name__ == '__main__':
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator',
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\train.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\music_tfd5_with_quantizer_basis.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\models\\27000_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
-                'conditioning_free': False, 'conditioning_free_k': 2,
+                'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 558, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 560, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
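
Note: the new denoising_fn, passed to p_sample_loop as denoised_fn with clip_denoised=False, is a
dynamic-thresholding step: each predicted mel is rescaled by its 95th-percentile value along the
time axis instead of being hard-clipped to [-1, 1]. Below is a minimal standalone sketch of that
behavior; the function name, the [B, C, T] tensor layout, and the example values are illustrative
assumptions, not part of the patch.

import torch

def dynamic_threshold(x: torch.Tensor, q: float = 0.95) -> torch.Tensor:
    # q-th quantile of values along the final (time) axis, per channel.
    s = torch.quantile(x, q=q, dim=-1).unsqueeze(-1)
    # Floor the threshold at 1 so in-range values pass through unchanged
    # (the patch writes this as q9.clamp(1, 9999999999)).
    s = s.clamp(min=1)
    # Clamp outliers to [-s, s], then rescale back into [-1, 1].
    return x.clamp(-s, s) / s

# Example: a mel-shaped batch with heavy outliers ends up bounded by 1.
x = torch.randn(1, 100, 400) * 4
print(dynamic_threshold(x).abs().max())  # <= 1.0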
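
Note: transformer_diffusion5.py now derives a diversity term from the quantizer's perplexity as
(num_codevectors - perplexity) / num_codevectors. The quantizer itself is outside this patch, so
the sketch below assumes the common Gumbel-codebook definition (as in Wav2Vec2-style quantizers),
where perplexity is exp(entropy) of the mean code distribution and approaches num_codevectors
under uniform code usage; all names here are illustrative.

import torch

def diversity_from_probs(probs: torch.Tensor, num_codevectors: int) -> torch.Tensor:
    # probs: [N, num_codevectors] soft code assignments for N frames.
    mean_probs = probs.mean(dim=0)
    # Perplexity = exp(entropy) of the average code distribution; it equals
    # num_codevectors when every codebook entry is used uniformly.
    perplexity = torch.exp(-torch.sum(mean_probs * torch.log(mean_probs + 1e-7)))
    # The patch's diversity term: ~0 under uniform usage, -> 1 as usage collapses.
    return (num_codevectors - perplexity) / num_codevectors

# Example: uniform usage over 8 codes gives diversity ~ 0.
probs = torch.full((100, 8), 1 / 8)
print(diversity_from_probs(probs, 8))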