From 195388712266dabd9619796e045dbfbc593560f3 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 1 Jul 2022 00:44:40 -0600 Subject: [PATCH] Add conditoning_masking to tfdpcv5 --- codes/models/audio/music/tfdpc_v5.py | 22 +++++++++++++++++----- codes/train.py | 2 +- codes/trainer/eval/music_diffusion_fid.py | 12 ++++++------ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index 09230b99..5115e773 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -1,5 +1,6 @@ import itertools import os +import random import torch import torch.nn as nn @@ -127,6 +128,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module): use_fp16=False, # Parameters for regularization. unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + conditioning_masking=0, ): super().__init__() @@ -136,6 +138,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module): self.out_channels = out_channels self.dropout = dropout self.unconditioned_percentage = unconditioned_percentage + self.conditioning_masking = conditioning_masking self.enable_fp16 = use_fp16 self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1) @@ -192,7 +195,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module): } return groups - def forward(self, x, timesteps, conditioning_input, conditioning_free=False, cond_start=0): + def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None): unused_params = [] time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) @@ -201,9 +204,18 @@ class TransformerDiffusionWithPointConditioning(nn.Module): cond = self.unconditioned_embedding cond = cond.repeat(1,x.shape[-1],1) else: - cond_enc = self.conditioning_encoder(conditioning_input, time_emb) - cs = cond_enc[:,:,cond_start] - ce = cond_enc[:,:,x.shape[-1]+cond_start] + if custom_conditioning_fetcher is not None: + cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb) + else: + if self.conditioning_masking > 0: + cond_op_len = x.shape[-1] + mask_len = int(cond_op_len * self.conditioning_masking) + if mask_len > 0: + start = random.randint(0, (cond_op_len-mask_len)) + cond_start + conditioning_input[:,:,start:(start+mask_len)] = 0 + cond_enc = self.conditioning_encoder(conditioning_input, time_emb) + cs = cond_enc[:,:,cond_start] + ce = cond_enc[:,:,x.shape[-1]+cond_start] cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1) cond = F.interpolate(cond_enc, size=(x.shape[-1],), mode='linear').permute(0,2,1) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. @@ -255,7 +267,7 @@ def test_cheater_model(): # For music: model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024, contraction_dim=512, num_heads=8, num_layers=15, dropout=0, - unconditioned_percentage=.4) + unconditioned_percentage=.4, conditioning_masking=.5) print_network(model) o = model(clip, ts, cl) pg = model.get_grad_norm_parameter_groups() diff --git a/codes/train.py b/codes/train.py index 42027511..f46318aa 100644 --- a/codes/train.py +++ b/codes/train.py @@ -339,7 +339,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_cheater_gen.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_ar_cheater_gen.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 982ba81e..a94ab7a0 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -306,15 +306,15 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen_r8.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5\\models\\46000_generator_ema.pth' + load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5\\models\\71000_generator_ema.pth' ).cuda() - opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) - #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. - 'diffusion_steps': 32, - 'conditioning_free': True, 'conditioning_free_k': 1, 'clip_audio': False, 'use_ddim': True, + opt_eval = {#'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :) + 'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety. + 'diffusion_steps': 128, + 'conditioning_free': True, 'conditioning_free_k': 2, 'clip_audio': False, 'use_ddim': True, 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen', #'partial_low': 128, 'partial_high': 192 } - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 200, 'device': 'cuda', 'opt': {}} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 225, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval())