diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py index d22130ad..2a6a9872 100644 --- a/codes/models/audio/music/transformer_diffusion8.py +++ b/codes/models/audio/music/transformer_diffusion8.py @@ -149,6 +149,7 @@ class TransformerDiffusion(nn.Module): def timestep_independent(self, prior, expected_seq_len): code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) + code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: @@ -156,7 +157,6 @@ class TransformerDiffusion(nn.Module): device=code_emb.device) < self.unconditioned_percentage code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1), code_emb) - code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) return expanded_code_emb @@ -215,7 +215,7 @@ class TransformerDiffusionWithQuantizer(nn.Module): self.quantizer.min_gumbel_temperature, ) - def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False): + def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): quant_grad_enabled = self.internal_step > self.freeze_quantizer_until with torch.set_grad_enabled(quant_grad_enabled): proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True) @@ -336,7 +336,8 @@ def test_ar_model(): cond = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024, - input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True) + input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True, + unconditioned_percentage=.4) model.get_grad_norm_parameter_groups() ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') diff --git a/codes/train.py b/codes/train.py index bc24da41..d1a5b3cf 100644 --- a/codes/train.py +++ b/codes/train.py @@ -339,7 +339,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt_upper.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_tfd.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)