diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index 60a7674d..e1df8de0 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -225,7 +225,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
             ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
         cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
-        cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
+        cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True).permute(0,2,1)
         return cond

     def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
@@ -304,6 +304,28 @@ def test_cheater_model():
         print(f'{k}: {prmsz(v)/1000000}')


+def test_conditioning_splitting_logic():
+    ts = torch.LongTensor([600])
+    class fake_conditioner(nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, t, _):
+            print(t[:,0])
+            return t
+    model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
+                                                      contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
+                                                      unconditioned_percentage=.4)
+    model.conditioning_encoder = fake_conditioner()
+    BASEDIM=30
+    for x in range(BASEDIM+1, BASEDIM+20):
+        start = random.randint(0,x-BASEDIM)
+        cl = torch.arange(1, x+1, 1).view(1,1,-1).float().repeat(1,256,1)
+        print("Effective input: " + str(cl[0, 0, start:BASEDIM+start]))
+        res = model.process_conditioning(cl, ts, BASEDIM, start, None)
+        print("Result: " + str(res[0,:,0]))
+        print()
+
+
 def inference_tfdpc5_with_cheater():
     with torch.no_grad():
         os.makedirs('results/tfdpc_v3', exist_ok=True)
@@ -384,5 +406,6 @@ def inference_tfdpc5_with_cheater():
             torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050)

 if __name__ == '__main__':
-    test_cheater_model()
+    #test_cheater_model()
+    test_conditioning_splitting_logic()
     #inference_tfdpc5_with_cheater()
diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 1d8d70da..10ede292 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -97,6 +97,7 @@ class TransformerDiffusion(nn.Module):
             out_channels=512,  # mean and variance
             num_heads=4,
             dropout=0,
+            use_corner_alignment=False,  # This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE.
             use_fp16=False,
             new_code_expansion=False,
             permute_codes=False,
@@ -117,6 +118,7 @@ class TransformerDiffusion(nn.Module):
         self.enable_fp16 = use_fp16
         self.new_code_expansion = new_code_expansion
         self.permute_codes = permute_codes
+        self.use_corner_alignment = use_corner_alignment

         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)

@@ -189,7 +191,7 @@ class TransformerDiffusion(nn.Module):

     def timestep_independent(self, prior, expected_seq_len):
         if self.new_code_expansion:
-            prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
+            prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear', align_corners=self.use_corner_alignment).permute(0,2,1)
         code_emb = self.input_converter(prior)
         code_emb = self.code_converter(code_emb)
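
Note (not part of the patch): the effect of the align_corners change above can be seen with a small standalone example. The tensor values below are invented purely for illustration; only F.interpolate and its align_corners argument come from the diff.

# Standalone sketch (not from the patch): how align_corners changes F.interpolate
# when a 2-point conditioning tensor is stretched to N frames, as in process_conditioning.
import torch
import torch.nn.functional as F

cond_enc = torch.tensor([[[0.0, 1.0]]])   # (batch=1, channels=1, length=2) -- made-up values
N = 5

default = F.interpolate(cond_enc, size=(N,), mode='linear')                      # align_corners defaults to False
aligned = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True)  # what the patch switches to

print(default)  # tensor([[[0.0000, 0.1000, 0.5000, 0.9000, 1.0000]]]) -- ramp flattens near the borders
print(aligned)  # tensor([[[0.0000, 0.2500, 0.5000, 0.7500, 1.0000]]]) -- uniform ramp between the two endpoints

With align_corners=True the first and last output frames correspond exactly to the start/end conditioning embeddings and the interpolation is uniform in between; since this changes outputs for existing checkpoints, the second file gates it behind the use_corner_alignment flag for backwards compatibility.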