diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 02c07a72..43d29421 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -60,9 +60,9 @@ class SubBlock(nn.Module):
 
 
 class ConcatAttentionBlock(TimestepBlock):
-    def __init__(self, trunk_dim, contraction_dim, heads, dropout):
+    def __init__(self, trunk_dim, contraction_dim, time_embed_dim, heads, dropout):
         super().__init__()
-        self.prenorm = RMSScaleShiftNorm(trunk_dim, bias=False)
+        self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
         self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
         self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
         self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
@@ -84,6 +84,7 @@ class TransformerDiffusion(nn.Module):
             self,
             prenet_channels=1024,
             prenet_layers=3,
+            time_embed_dim=256,
             model_channels=1024,
             contraction_dim=256,
             num_layers=8,
@@ -103,6 +104,7 @@ class TransformerDiffusion(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.prenet_channels = prenet_channels
+        self.time_embed_dim = time_embed_dim
         self.out_channels = out_channels
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
@@ -111,9 +113,9 @@
 
         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
         self.time_embed = nn.Sequential(
-            linear(prenet_channels, prenet_channels),
+            linear(time_embed_dim, time_embed_dim),
             nn.SiLU(),
-            linear(prenet_channels, model_channels),
+            linear(time_embed_dim, time_embed_dim),
         )
 
         self.ar_prior = ar_prior
@@ -150,7 +152,7 @@
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
         self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
         self.intg = nn.Linear(prenet_channels*2, model_channels)
-        self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
+        self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
 
         self.out = nn.Sequential(
             normalization(model_channels),
@@ -197,7 +199,7 @@
             unused_params.append(self.unconditioned_embedding)
 
         with torch.autocast(x.device.type, enabled=self.enable_fp16):
-            blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
+            blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
             x = self.inp_block(x).permute(0,2,1)
 
             rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
@@ -273,6 +275,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
         blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
         groups = {
+            'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
             'blk1_attention_layers': attn1,
             'blk2_attention_layers': attn2,
             'attention_layers': attn1 + attn2,
@@ -291,6 +294,16 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         }
         return groups
 
+    def before_step(self, step):
+        scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
+                                 list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
+        # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
+        # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
+        # directly fiddling with the gradients.
+        for p in scaled_grad_parameters:
+            p.grad *= .2
+
+
 class TransformerDiffusionWithARPrior(nn.Module):
     def __init__(self, freeze_diff=False, **kwargs):
@@ -353,8 +366,8 @@ def test_quant_model():
     ts = torch.LongTensor([600, 600])
 
     # For music:
-    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, contraction_dim=512,
-                                              prenet_channels=1024, num_heads=8,
+    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1536, contraction_dim=768,
+                                              prenet_channels=1024, num_heads=10,
                                               input_vec_dim=1024, num_layers=24, prenet_layers=4, dropout=.1)
     quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
@@ -363,7 +376,18 @@ def test_quant_model():
     print_network(model)
     o = model(clip, ts, clip)
-    model.get_grad_norm_parameter_groups()
+    pg = model.get_grad_norm_parameter_groups()
+    t = 0
+    for k, vs in pg.items():
+        s = 0
+        for v in vs:
+            m = 1
+            for d in v.shape:
+                m *= d
+            s += m
+        t += s
+        print(k, s/1000000)
+    print(t)
 
 
 def test_ar_model():
diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index 3b482e8b..7056a485 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -352,12 +352,13 @@ class RMSNorm(nn.Module):
 
 
 class RMSScaleShiftNorm(nn.Module):
-    def __init__(self, dim, eps=1e-8, bias=True):
+    def __init__(self, dim, embed_dim=None, eps=1e-8, bias=True):
         super().__init__()
+        embed_dim = default(embed_dim, dim)
         self.scale = dim ** -0.5
         self.eps = eps
         self.g = nn.Parameter(torch.ones(dim))
-        self.scale_shift_process = nn.Linear(dim, dim * 2, bias=bias)
+        self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)
 
     def forward(self, x, norm_scale_shift_inp):
         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
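
Illustrative sketch (not part of the patch above): the RMSScaleShiftNorm change decouples the width of the conditioning embedding from the trunk width, which is what allows TransformerDiffusion to feed a time_embed_dim=256 timestep embedding into a model_channels-wide trunk instead of forcing it to match prenet_channels. The class name ScaleShiftNormSketch is made up for illustration, and the scale/shift application in forward() is assumed from the usual AdaLN-style pattern, since the hunk above only shows the first line of the real forward().

import torch
import torch.nn as nn

class ScaleShiftNormSketch(nn.Module):
    # Minimal stand-in for RMSScaleShiftNorm after this patch: the conditioning
    # embedding may now be narrower (embed_dim) than the normalized trunk (dim).
    def __init__(self, dim, embed_dim=None, eps=1e-8):
        super().__init__()
        embed_dim = embed_dim if embed_dim is not None else dim
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))
        # Projects the (possibly smaller) embedding to a per-channel scale and shift.
        self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=False)

    def forward(self, x, emb):
        # RMS-style normalization of the trunk activations.
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        x = x / norm.clamp(min=self.eps) * self.g
        # Scale/shift conditioning (assumed; the hunk above truncates forward()).
        scale, shift = self.scale_shift_process(emb).chunk(2, dim=-1)
        return x * (1 + scale) + shift

norm = ScaleShiftNormSketch(dim=1536, embed_dim=256)
trunk = torch.randn(2, 100, 1536)   # (batch, sequence, model_channels)
t_emb = torch.randn(2, 1, 256)      # (batch, 1, time_embed_dim), broadcast over the sequence
print(norm(trunk, t_emb).shape)     # torch.Size([2, 100, 1536])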