diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index e1b748ca..208b4c6f 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -100,6 +100,7 @@ class TransformerDiffusion(nn.Module): use_fp16=False, ar_prior=False, new_code_expansion=False, + permute_codes=False, # Parameters for regularization. unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. # Parameters for re-training head @@ -116,6 +117,7 @@ class TransformerDiffusion(nn.Module): self.unconditioned_percentage = unconditioned_percentage self.enable_fp16 = use_fp16 self.new_code_expansion = new_code_expansion + self.permute_codes = permute_codes self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) @@ -228,6 +230,8 @@ class TransformerDiffusion(nn.Module): def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False): if precomputed_code_embeddings is not None: assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." + if self.permute_codes: + codes = codes.permute(0,2,1) unused_params = [] if conditioning_free: @@ -605,6 +609,15 @@ def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt): return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs']) +def test_tfd(): + clip = torch.randn(2,256,400) + ts = torch.LongTensor([600, 600]) + model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512, + prenet_channels=1024, num_heads=3, permute_codes=True, + input_vec_dim=256, num_layers=12, prenet_layers=4, + dropout=.1) + model(clip, ts, clip) + def test_quant_model(): clip = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) @@ -767,4 +780,4 @@ def extract_diff(in_f, out_f, remove_head=False): if __name__ == '__main__': #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True) - test_cheater_model() + test_tfd()