This commit is contained in:
James Betker 2022-06-10 22:39:15 -06:00
parent aca9024d9b
commit acfe9cf880

View File

@ -175,13 +175,14 @@ class TransformerDiffusion(nn.Module):
code_emb = self.timestep_independent(codes, x.shape[-1])
unused_params.append(self.unconditioned_embedding)
blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
x = self.inp_block(x).permute(0,2,1)
with torch.autocast(x.device.type, enabled=self.enable_fp16):
blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
x = self.inp_block(x).permute(0,2,1)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
x = self.intg(torch.cat([x, code_emb], dim=-1))
for layer in self.layers:
x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
x = self.intg(torch.cat([x, code_emb], dim=-1))
for layer in self.layers:
x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
x = x.float().permute(0,2,1)
out = self.out(x)
@ -318,9 +319,9 @@ def test_quant_model():
clip = torch.randn(2, 256, 400)
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, block_channels=1024,
prenet_channels=1024, num_heads=8,
input_vec_dim=1024, num_layers=16, prenet_layers=6)
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=3072, block_channels=1536,
prenet_channels=1024, num_heads=12,
input_vec_dim=1024, num_layers=24, prenet_layers=6)
model.get_grad_norm_parameter_groups()
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
@ -335,8 +336,8 @@ def test_ar_model():
clip = torch.randn(2, 256, 400)
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
model = TransformerDiffusionWithARPrior(model_channels=3072, block_channels=1536, prenet_channels=1536,
input_vec_dim=512, num_layers=24, prenet_layers=6, freeze_diff=True,
unconditioned_percentage=.4)
model.get_grad_norm_parameter_groups()