diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index 97a78172..9e7d2e29 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -215,6 +215,15 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             left_pt = -1
             cond_right = conditioning_input[:,:,cond_start+N:]
             right_pt = 0
+
+            if self.training:
+                # Arbitrarily restrict the context given. We should support short contexts and without this they are never encountered.
+                arb_context_cap = random.randint(50, 100)
+                # NOTE: must be random.random() — `random` here is the module (see random.randint above),
+                # so a bare random() call would raise TypeError at runtime.
+                if cond_left.shape[-1] > arb_context_cap and random.random() > .5:
+                    cond_left = cond_left[:,:,-arb_context_cap:]
+                if cond_right.shape[-1] > arb_context_cap and random.random() > .5:
+                    cond_right = cond_right[:,:,:arb_context_cap]
+
         elif cond_left is None:
             assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
             cond_pre = conditioning_input[:,:,:cond_start]
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index b5e010c1..9dfc5986 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -230,10 +230,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
         # 1. Generate the cheater latent using the input as a reference.
         sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
-        output_shape = (1, 256, cheater.shape[-1]-80)
-        gen_cheater = sampler(self.model, output_shape, progress=True,
+        # center-pad the conditioning input (the center isn't actually used). this is hack for giving tfdpc5 a bigger working context.
+        cheater_padded = torch.cat([cheater[:,:,cheater.shape[-1]//2:], torch.zeros(1,256,160, device=cheater.device), cheater[:,:,:cheater.shape[-1]//2]], dim=-1)
+        gen_cheater = sampler(self.model, cheater.shape, progress=True,
                               causal=self.causal, causal_slope=self.causal_slope,
-                              model_kwargs={'conditioning_input': cheater, 'cond_start': 40})
+                              model_kwargs={'conditioning_input': cheater_padded, 'cond_start': 80})

         # 2. Decode the cheater into a MEL
         gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,gen_cheater.shape[-1]*16), progress=True,