fp16
This commit is contained in:
parent
aca9024d9b
commit
acfe9cf880
|
@ -175,13 +175,14 @@ class TransformerDiffusion(nn.Module):
|
|||
code_emb = self.timestep_independent(codes, x.shape[-1])
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
|
||||
x = self.inp_block(x).permute(0,2,1)
|
||||
with torch.autocast(x.device.type, enabled=self.enable_fp16):
|
||||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
|
||||
x = self.inp_block(x).permute(0,2,1)
|
||||
|
||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
||||
x = self.intg(torch.cat([x, code_emb], dim=-1))
|
||||
for layer in self.layers:
|
||||
x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
|
||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
||||
x = self.intg(torch.cat([x, code_emb], dim=-1))
|
||||
for layer in self.layers:
|
||||
x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
|
||||
|
||||
x = x.float().permute(0,2,1)
|
||||
out = self.out(x)
|
||||
|
@ -318,9 +319,9 @@ def test_quant_model():
|
|||
clip = torch.randn(2, 256, 400)
|
||||
cond = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, block_channels=1024,
|
||||
prenet_channels=1024, num_heads=8,
|
||||
input_vec_dim=1024, num_layers=16, prenet_layers=6)
|
||||
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=3072, block_channels=1536,
|
||||
prenet_channels=1024, num_heads=12,
|
||||
input_vec_dim=1024, num_layers=24, prenet_layers=6)
|
||||
model.get_grad_norm_parameter_groups()
|
||||
|
||||
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
|
||||
|
@ -335,8 +336,8 @@ def test_ar_model():
|
|||
clip = torch.randn(2, 256, 400)
|
||||
cond = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
|
||||
input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
|
||||
model = TransformerDiffusionWithARPrior(model_channels=3072, block_channels=1536, prenet_channels=1536,
|
||||
input_vec_dim=512, num_layers=24, prenet_layers=6, freeze_diff=True,
|
||||
unconditioned_percentage=.4)
|
||||
model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user