From ff5c03b460e17e84d97b4d443e537e8c4c233297 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 15 Jun 2022 08:58:02 -0600
Subject: [PATCH] tfd12 with ar prior

---
 .../audio/music/transformer_diffusion12.py   | 35 ++++++++++++++++---
 .../audio/tts/unet_diffusion_tts_flat.py     | 18 ++++++++--
 codes/train.py                               |  2 +-
 3 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 128f5168..0abcfacb 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -1,4 +1,5 @@
 import itertools
+from time import time
 
 import torch
 import torch.nn as nn
@@ -99,6 +100,8 @@ class TransformerDiffusion(nn.Module):
             ar_prior=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
+            # Parameters for re-training head
+            freeze_except_code_converters=False,
     ):
         super().__init__()
 
@@ -161,6 +164,16 @@ class TransformerDiffusion(nn.Module):
             zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
         )
 
+        if freeze_except_code_converters:
+            for p in self.parameters():
+                p.DO_NOT_TRAIN = True
+                p.requires_grad = False
+            for m in [self.input_converter, self.code_converter]:
+                for p in m.parameters():
+                    del p.DO_NOT_TRAIN
+                    p.requires_grad = True
+
+
         self.debug_codes = {}
 
     def get_grad_norm_parameter_groups(self):
@@ -391,7 +404,7 @@ class TransformerDiffusionWithPretrainedVqvae(nn.Module):
             'out': list(self.diff.out.parameters()),
             'x_proj': list(self.diff.inp_block.parameters()),
             'layers': list(self.diff.layers.parameters()),
-            'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
+            #'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
             'time_embed': list(self.diff.time_embed.parameters()),
         }
         return groups
@@ -534,7 +547,7 @@ def test_vqvae_model():
     model = TransformerDiffusionWithPretrainedVqvae(in_channels=100, out_channels=200,
                                                     model_channels=1024, contraction_dim=512,
                                                     prenet_channels=1024, num_heads=8,
-                                                    input_vec_dim=512, num_layers=12, prenet_layers=6,
+                                                    input_vec_dim=512, num_layers=12, prenet_layers=6, ar_prior=True,
                                                     dropout=.1, vqargs= {
                                                         'positional_dims': 1, 'channels': 80,
                                                         'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
@@ -549,6 +562,20 @@ def test_vqvae_model():
     o = model(clip, ts, cond)
     pg = model.get_grad_norm_parameter_groups()
 
+    """
+    with torch.no_grad():
+        proj = torch.randn(2, 100, 512).cuda()
+        clip = clip.cuda()
+        ts = ts.cuda()
+        start = time()
+        model = model.cuda().eval()
+        model.diff.enable_fp16 = True
+        ti = model.diff.timestep_independent(proj, clip.shape[2])
+        for k in range(100):
+            model.diff(clip, ts, precomputed_code_embeddings=ti)
+        print(f"Elapsed: {time()-start}")
+    """
+
 
 def test_multi_vqvae_model():
     clip = torch.randn(2, 256, 400)
@@ -556,7 +583,7 @@ def test_multi_vqvae_model():
     ts = torch.LongTensor([600, 600])
 
     # For music:
-    model = TransformerDiffusionWithMultiPretrainedVqvae(in_channels=256, out_channels=200,
+    model = TransformerDiffusionWithMultiPretrainedVqvae(in_channels=256, out_channels=512,
                                                          model_channels=1024, contraction_dim=512,
                                                          prenet_channels=1024, num_heads=8,
                                                          input_vec_dim=2048, num_layers=12, prenet_layers=6,
@@ -604,4 +631,4 @@ def test_ar_model():
 
 
 if __name__ == '__main__':
-    test_multi_vqvae_model()
+    test_vqvae_model()
diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py
index ce0ef9e5..ca02171f 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py
@@ -1,4 +1,5 @@
 import random
+from time import time
 
 import torch
 import torch.nn as nn
@@ -320,9 +321,22 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8192,(2,100))
     cond = torch.randn(2, 100, 400)
     ts = torch.LongTensor([600, 600])
-    model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5, freeze_everything_except_autoregressive_inputs=True)
+    model = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
+                             in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=True, num_heads=16,
+                             layer_drop=0, unconditioned_percentage=0)
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
-    o = model(clip, ts, aligned_sequence, cond)
+    #o = model(clip, ts, aligned_sequence, cond)
+
+    with torch.no_grad():
+        proj = torch.randn(2, 100, 1024).cuda()
+        clip = clip.cuda()
+        ts = ts.cuda()
+        start = time()
+        model = model.cuda().eval()
+        ti = model.timestep_independent(proj, clip, clip.shape[2], False)
+        for k in range(100):
+            model(clip, ts, precomputed_aligned_embeddings=ti)
+        print(f"Elapsed: {time()-start}")
 
diff --git a/codes/train.py b/codes/train.py
index d1a5b3cf..3c242f36 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -339,7 +339,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_tfd.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_unified_alignment.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
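
For context, the freeze_except_code_converters flag added in transformer_diffusion12.py follows a freeze-then-selectively-unfreeze pattern: every parameter is tagged with a DO_NOT_TRAIN attribute and has requires_grad disabled, then only the converter submodules are re-enabled. Below is a minimal, self-contained sketch of that pattern on a toy module. ToyModel, its submodule layout, and the trainable_parameters helper are illustrative stand-ins and are not part of this patch or repository; how the repo's trainer actually interprets DO_NOT_TRAIN is assumed, not shown.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Illustrative stand-in: a frozen trunk plus two small 'converter' heads."""
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
        self.input_converter = nn.Linear(16, 32)
        self.code_converter = nn.Linear(16, 32)

    def freeze_except_converters(self):
        # Freeze everything: tag each parameter and stop gradient tracking.
        for p in self.parameters():
            p.DO_NOT_TRAIN = True
            p.requires_grad = False
        # Selectively re-enable only the converter heads
        # (note the comma, so both submodules are visited).
        for m in [self.input_converter, self.code_converter]:
            for p in m.parameters():
                del p.DO_NOT_TRAIN
                p.requires_grad = True


def trainable_parameters(model: nn.Module):
    # Hypothetical helper: the set an optimizer would pick up after freezing.
    return [p for p in model.parameters()
            if p.requires_grad and not getattr(p, 'DO_NOT_TRAIN', False)]


if __name__ == '__main__':
    model = ToyModel()
    model.freeze_except_converters()
    opt = torch.optim.AdamW(trainable_parameters(model), lr=1e-4)
    print(sum(p.numel() for p in trainable_parameters(model)), 'trainable parameters')

Tagging parameters in addition to flipping requires_grad lets a trainer filter frozen weights out of its parameter groups while autograd also skips gradient allocation for them; the sketch only demonstrates the freezing mechanics, not the retraining of the converter head itself.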