diff --git a/codes/models/audio/music/transformer_diffusion14.py b/codes/models/audio/music/transformer_diffusion14.py index 9143409e..0d163f8a 100644 --- a/codes/models/audio/music/transformer_diffusion14.py +++ b/codes/models/audio/music/transformer_diffusion14.py @@ -149,7 +149,7 @@ class TransformerDiffusion(nn.Module): def forward(self, x, timesteps, prior=None, conditioning_free=False): if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) + code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) else: code_emb = self.input_converter(prior) @@ -260,6 +260,7 @@ def test_cheater_model(): print_network(model) o = model(clip, ts, clip) + o = model(clip, ts, clip, conditioning_free=True) pg = model.get_grad_norm_parameter_groups() @@ -274,6 +275,6 @@ def extract_cheater_encoder(in_f, out_f): if __name__ == '__main__': #test_local_attention_mask() - extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True) + #extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True) test_cheater_model() #extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)