diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index 07862598..7efe1a3b 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -114,9 +114,6 @@ class ConditioningEncoder(nn.Module):
 
 
 class TransformerDiffusionWithPointConditioning(nn.Module):
-    """
-    A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
-    """
     def __init__(
             self,
             in_channels=256,
@@ -129,9 +126,11 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             input_cond_dim=1024,
             num_heads=8,
             dropout=0,
-            time_proj=False,
+            time_proj=True,
+            new_cond=False,
             use_fp16=False,
             checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
+            regularization=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
     ):
@@ -144,6 +143,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
+        self.regularization = regularization
+        self.new_cond = new_cond
 
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
         self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)
@@ -166,13 +167,11 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                 cond_projection=(k % 3 == 0),
                 use_conv=(k % 3 != 0),
             ) for k in range(num_layers)])
-
         self.out = nn.Sequential(
             normalization(model_channels),
             nn.SiLU(),
             zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
         )
-
         self.debug_codes = {}
 
     def get_grad_norm_parameter_groups(self):
@@ -199,10 +198,24 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         }
         return groups
 
-    def process_conditioning(self, conditioning_input, time_emb, N, cond_start, custom_conditioning_fetcher):
-        if custom_conditioning_fetcher is not None:
-            cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
-        else:
+    def process_conditioning(self, conditioning_input, time_emb, N, cond_start, cond_left, cond_right):
+        if self.training and self.regularization:
+            # frequency regularization
+            fstart = random.randint(0, conditioning_input.shape[1] - 1)
+            fclip = random.randint(1, min(conditioning_input.shape[1]-fstart, 16))
+            conditioning_input[:,fstart:fstart+fclip] = 0
+            # time regularization
+            for k in range(1, random.randint(2, 4)):
+                tstart = random.randint(0, conditioning_input.shape[-1] - 1)
+                tclip = random.randint(1, min(conditioning_input.shape[-1]-tstart, 10))
+                conditioning_input[:,:,tstart:tstart+tclip] = 0
+
+        if cond_left is None and self.new_cond:
+            cond_left = conditioning_input[:,:,:max(cond_start, 20)]
+            left_pt = cond_start
+            cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):]
+            right_pt = cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start))
+        elif cond_left is None:
             assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
             cond_pre = conditioning_input[:,:,:cond_start]
             cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
@@ -223,19 +236,25 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                 cond_right = cond_right[:,:,to_remove_right:]
 
             # Concatenate the _pre and _post back on.
-            cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
-            cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
+            left_pt = cond_start
+            right_pt = cond_right.shape[-1]
+            cond_left = torch.cat([cond_pre, cond_left], dim=-1)
+            cond_right = torch.cat([cond_right, cond_post], dim=-1)
+        else:
+            left_pt = -1
+            right_pt = 0
+
+        # Propagate through the encoder.
+        cond_left_enc = self.conditioning_encoder(cond_left, time_emb)
+        cs = cond_left_enc[:,:,left_pt]
+        cond_right_enc = self.conditioning_encoder(cond_right, time_emb)
+        ce = cond_right_enc[:,:,right_pt]
-            # Propagate through the encoder.
-            cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
-            cs = cond_left_enc[:,:,cond_start]
-            cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
-            ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
 
         cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
         cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True).permute(0,2,1)
         return cond
 
-    def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
+    def forward(self, x, timesteps, conditioning_input=None, cond_left=None, cond_right=None, conditioning_free=False, cond_start=0):
         unused_params = []
         time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
 
@@ -244,7 +263,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             cond = self.unconditioned_embedding
             cond = cond.repeat(1,x.shape[-1],1)
         else:
-            cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, custom_conditioning_fetcher)
+            cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, cond_left, cond_right)
             # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
             if self.training and self.unconditioned_percentage > 0:
                 unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
@@ -276,9 +295,10 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
-        for p in scaled_grad_parameters:
-            if hasattr(p, 'grad') and p.grad is not None:
-                p.grad *= .2
+        if not self.new_cond:  # Not really related, I just don't want to add a new config.
+            for p in scaled_grad_parameters:
+                if hasattr(p, 'grad') and p.grad is not None:
+                    p.grad *= .2
 
 
 @register_model
@@ -293,11 +313,12 @@ def test_cheater_model():
 
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
-                                                      contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4, checkpoint_conditioning=False)
+                                                      contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
+                                                      unconditioned_percentage=.4, checkpoint_conditioning=False,
+                                                      regularization=True, new_cond=True)
     print_network(model)
-    for k in range(100):
-        o = model(clip, ts, cl)
+    for cs in range(276,cl.shape[-1]-clip.shape[-1]):
+        o = model(clip, ts, cl, cond_start=cs)
     pg = model.get_grad_norm_parameter_groups()
     def prmsz(lp):
         sz = 0
diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 385c790f..ed0a8cf9 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -617,8 +617,6 @@ class GaussianDiffusion:
         mask,
         noise=None,
         clip_denoised=True,
-        causal=False,
-        causal_slope=1,
         denoised_fn=None,
         cond_fn=None,
         model_kwargs=None,
@@ -640,8 +638,6 @@
             img,
             t,
             clip_denoised=clip_denoised,
-            causal=causal,
-            causal_slope=causal_slope,
             denoised_fn=denoised_fn,
             cond_fn=cond_fn,
             model_kwargs=model_kwargs,
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 53eb9d51..b29309be 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -436,18 +436,18 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator', also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\53000_generator_ema.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\80500_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
-                'diffusion_steps': 220,  # basis: 192
+                'diffusion_steps': 256,  # basis: 192
                 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
                 # Slope 1: 1.03x, 2: 1.06, 4: 1.135, 8: 1.27, 16: 1.54
-                'causal': True, 'causal_slope': 3,  # DONT FORGET TO INCREMENT THE STEP!
+                'causal': True, 'causal_slope': 4,  # DONT FORGET TO INCREMENT THE STEP!
                 #'partial_low': 128, 'partial_high': 192
                 }
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 3, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 104, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     fds = []
     for i in range(2):
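
A note on the mechanism the tfdpc_v5 changes plug into: `process_conditioning` collapses the left and right conditioning windows into one encoded vector each (`cs` taken at `left_pt`, `ce` at `right_pt`) and then linearly interpolates between those two points across the N output frames. Below is a minimal standalone sketch of that final interpolation step; the function name and shapes are illustrative assumptions, not code from the repo.

```python
import torch
import torch.nn.functional as F

def interpolate_point_conditioning(cs: torch.Tensor, ce: torch.Tensor, n_frames: int) -> torch.Tensor:
    """Blend a start embedding `cs` and an end embedding `ce`, each (batch, channels),
    into a (batch, n_frames, channels) conditioning signal, mirroring the
    cat + F.interpolate + permute sequence in process_conditioning."""
    cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)   # (b, c, 2)
    cond = F.interpolate(cond_enc, size=(n_frames,), mode='linear', align_corners=True)
    return cond.permute(0, 2, 1)                                         # (b, n_frames, c)

if __name__ == '__main__':
    cond = interpolate_point_conditioning(torch.randn(2, 1024), torch.randn(2, 1024), 350)
    print(cond.shape)  # torch.Size([2, 350, 1024])
```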