forked from mrq/DL-Art-School
fp16
This commit is contained in:
parent aca9024d9b
commit acfe9cf880
|
@ -175,13 +175,14 @@ class TransformerDiffusion(nn.Module):
|
||||||
code_emb = self.timestep_independent(codes, x.shape[-1])
|
code_emb = self.timestep_independent(codes, x.shape[-1])
|
||||||
unused_params.append(self.unconditioned_embedding)
|
unused_params.append(self.unconditioned_embedding)
|
||||||
|
|
||||||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
|
with torch.autocast(x.device.type, enabled=self.enable_fp16):
|
||||||
x = self.inp_block(x).permute(0,2,1)
|
blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
|
||||||
|
x = self.inp_block(x).permute(0,2,1)
|
||||||
|
|
||||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
||||||
x = self.intg(torch.cat([x, code_emb], dim=-1))
|
x = self.intg(torch.cat([x, code_emb], dim=-1))
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
|
x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
|
||||||
|
|
||||||
x = x.float().permute(0,2,1)
|
x = x.float().permute(0,2,1)
|
||||||
out = self.out(x)
|
out = self.out(x)
|
||||||
|
@ -318,9 +319,9 @@ def test_quant_model():
|
||||||
clip = torch.randn(2, 256, 400)
|
clip = torch.randn(2, 256, 400)
|
||||||
cond = torch.randn(2, 256, 400)
|
cond = torch.randn(2, 256, 400)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, block_channels=1024,
|
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=3072, block_channels=1536,
|
||||||
prenet_channels=1024, num_heads=8,
|
prenet_channels=1024, num_heads=12,
|
||||||
input_vec_dim=1024, num_layers=16, prenet_layers=6)
|
input_vec_dim=1024, num_layers=24, prenet_layers=6)
|
||||||
model.get_grad_norm_parameter_groups()
|
model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
|
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
|
||||||
|
@ -335,8 +336,8 @@ def test_ar_model():
|
||||||
clip = torch.randn(2, 256, 400)
|
clip = torch.randn(2, 256, 400)
|
||||||
cond = torch.randn(2, 256, 400)
|
cond = torch.randn(2, 256, 400)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
|
model = TransformerDiffusionWithARPrior(model_channels=3072, block_channels=1536, prenet_channels=1536,
|
||||||
input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
|
input_vec_dim=512, num_layers=24, prenet_layers=6, freeze_diff=True,
|
||||||
unconditioned_percentage=.4)
|
unconditioned_percentage=.4)
|
||||||
model.get_grad_norm_parameter_groups()
|
model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user