From 48270272e78392147bfad32e3f61c05d1492a902 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 6 Jul 2022 16:45:03 -0600 Subject: [PATCH] Use corner alignment for linear interpolation in TFDPC and TFD12 I noticed from experimentation that when this is not enabled, the interpolation edges are "sticky", which is to say there is more variance in the center of the interpolation than at the edges. --- codes/models/audio/music/tfdpc_v5.py | 27 +++++++++++++++++-- .../audio/music/transformer_diffusion12.py | 4 ++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index 60a7674d..e1df8de0 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -225,7 +225,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module): cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb) ce = cond_right_enc[:,:,cond_right.shape[-1]-1] cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1) - cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1) + cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True).permute(0,2,1) return cond def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None): @@ -304,6 +304,28 @@ def test_cheater_model(): print(f'{k}: {prmsz(v)/1000000}') +def test_conditioning_splitting_logic(): + ts = torch.LongTensor([600]) + class fake_conditioner(nn.Module): + def __init__(self): + super().__init__() + def forward(self, t, _): + print(t[:,0]) + return t + model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024, + contraction_dim=512, num_heads=8, num_layers=15, dropout=0, + unconditioned_percentage=.4) + model.conditioning_encoder = fake_conditioner() + BASEDIM=30 + for x in range(BASEDIM+1, BASEDIM+20): + start = random.randint(0,x-BASEDIM) + cl = torch.arange(1, x+1, 1).view(1,1,-1).float().repeat(1,256,1) + print("Effective input: " + str(cl[0, 0, start:BASEDIM+start])) + res = model.process_conditioning(cl, ts, BASEDIM, start, None) + print("Result: " + str(res[0,:,0])) + print() + + def inference_tfdpc5_with_cheater(): with torch.no_grad(): os.makedirs('results/tfdpc_v3', exist_ok=True) @@ -384,5 +406,6 @@ def inference_tfdpc5_with_cheater(): torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050) if __name__ == '__main__': - test_cheater_model() + #test_cheater_model() + test_conditioning_splitting_logic() #inference_tfdpc5_with_cheater() diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index 1d8d70da..10ede292 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -97,6 +97,7 @@ class TransformerDiffusion(nn.Module): out_channels=512, # mean and variance num_heads=4, dropout=0, + use_corner_alignment=False, # This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE. use_fp16=False, new_code_expansion=False, permute_codes=False, @@ -117,6 +118,7 @@ class TransformerDiffusion(nn.Module): self.enable_fp16 = use_fp16 self.new_code_expansion = new_code_expansion self.permute_codes = permute_codes + self.use_corner_alignment = use_corner_alignment self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) @@ -189,7 +191,7 @@ class TransformerDiffusion(nn.Module): def timestep_independent(self, prior, expected_seq_len): if self.new_code_expansion: - prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1) + prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear', align_corners=self.use_corner_alignment).permute(0,2,1) code_emb = self.input_converter(prior) code_emb = self.code_converter(code_emb)