From acfe9cf880ff5228fe782ac3ecbe3ebeaf5ceb45 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 10 Jun 2022 22:39:15 -0600
Subject: [PATCH] fp16

---
 .../audio/music/transformer_diffusion8.py     | 23 ++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py
index 973c83ca..5791be02 100644
--- a/codes/models/audio/music/transformer_diffusion8.py
+++ b/codes/models/audio/music/transformer_diffusion8.py
@@ -175,13 +175,14 @@ class TransformerDiffusion(nn.Module):
             code_emb = self.timestep_independent(codes, x.shape[-1])
             unused_params.append(self.unconditioned_embedding)
 
-        blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
-        x = self.inp_block(x).permute(0,2,1)
+        with torch.autocast(x.device.type, enabled=self.enable_fp16):
+            blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
+            x = self.inp_block(x).permute(0,2,1)
 
-        rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
-        x = self.intg(torch.cat([x, code_emb], dim=-1))
-        for layer in self.layers:
-            x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
+            rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
+            x = self.intg(torch.cat([x, code_emb], dim=-1))
+            for layer in self.layers:
+                x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
 
         x = x.float().permute(0,2,1)
         out = self.out(x)
@@ -318,9 +319,9 @@ def test_quant_model():
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, block_channels=1024,
-                                              prenet_channels=1024, num_heads=8,
-                                              input_vec_dim=1024, num_layers=16, prenet_layers=6)
+    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=3072, block_channels=1536,
+                                              prenet_channels=1024, num_heads=12,
+                                              input_vec_dim=1024, num_layers=24, prenet_layers=6)
     model.get_grad_norm_parameter_groups()
 
     quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
@@ -335,8 +336,8 @@ def test_ar_model():
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
-                                            input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
+    model = TransformerDiffusionWithARPrior(model_channels=3072, block_channels=1536, prenet_channels=1536,
+                                            input_vec_dim=512, num_layers=24, prenet_layers=6, freeze_diff=True,
                                             unconditioned_percentage=.4)
     model.get_grad_norm_parameter_groups()
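
Note on the first hunk: the prenet, input block, and transformer stack now run under
torch.autocast, gated by self.enable_fp16, and activations are cast back to float32
(x.float()) before the output projection. A minimal standalone sketch of the same
pattern follows; ToyModel and its dimensions are hypothetical stand-ins for
illustration, not code from this repo:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical stand-in showing the autocast pattern used in the patch."""
    def __init__(self, channels=64, enable_fp16=True):
        super().__init__()
        self.enable_fp16 = enable_fp16
        self.inp_block = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.layers = nn.ModuleList(nn.Linear(channels, channels) for _ in range(4))
        self.out = nn.Linear(channels, channels)

    def forward(self, x):
        # Run the bulk of the network under autocast: eligible ops execute in
        # float16 on CUDA (bfloat16 on CPU) when enable_fp16 is True.
        with torch.autocast(x.device.type, enabled=self.enable_fp16):
            h = self.inp_block(x).permute(0, 2, 1)
            for layer in self.layers:
                h = layer(h)
        # Cast back to float32 before the output head, mirroring the x.float()
        # call in the patch, so the final projection runs in full precision.
        return self.out(h.float())

if __name__ == '__main__':
    m = ToyModel()
    y = m(torch.randn(2, 64, 400))
    print(y.dtype)  # torch.float32: the low-precision region stays inside forward

Keeping the output head outside the autocast region is a common mixed-precision
choice: the reduced-precision region covers the expensive transformer layers, while
the final projection (and any loss computed on it) stays in float32.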