more reworks
commit 1fde3e5a08
parent 7a36668870
@@ -60,9 +60,9 @@ class SubBlock(nn.Module):
 
 
 class ConcatAttentionBlock(TimestepBlock):
-    def __init__(self, trunk_dim, contraction_dim, heads, dropout):
+    def __init__(self, trunk_dim, contraction_dim, time_embed_dim, heads, dropout):
         super().__init__()
-        self.prenorm = RMSScaleShiftNorm(trunk_dim, bias=False)
+        self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
         self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
         self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
         self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
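Reviewer note: the substance of this hunk is that the timestep-embedding width is decoupled from the trunk width; each block's prenorm now projects from `time_embed_dim` rather than assuming the embedding matches `trunk_dim`. A minimal, self-contained sketch of the shape contract this enables (names and sizes here are illustrative, not from the commit):

```python
import torch
import torch.nn as nn

trunk_dim, time_embed_dim = 1024, 256
# What RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim) enables: a narrow
# time embedding conditions a wide trunk through a learned (scale, shift) pair.
to_scale_shift = nn.Linear(time_embed_dim, trunk_dim * 2, bias=False)

x = torch.randn(2, 100, trunk_dim)         # (batch, seq, trunk_dim) trunk activations
t_emb = torch.randn(2, 1, time_embed_dim)  # (batch, 1, time_embed_dim) time embedding
scale, shift = to_scale_shift(t_emb).chunk(2, dim=-1)
y = x * (1 + scale) + shift                # broadcasts across the sequence dim
print(y.shape)                             # torch.Size([2, 100, 1024])
```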
@@ -84,6 +84,7 @@ class TransformerDiffusion(nn.Module):
             self,
             prenet_channels=1024,
             prenet_layers=3,
+            time_embed_dim=256,
             model_channels=1024,
             contraction_dim=256,
             num_layers=8,
@@ -103,6 +104,7 @@ class TransformerDiffusion(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.prenet_channels = prenet_channels
+        self.time_embed_dim = time_embed_dim
         self.out_channels = out_channels
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
@@ -111,9 +113,9 @@ class TransformerDiffusion(nn.Module):
         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
 
         self.time_embed = nn.Sequential(
-            linear(prenet_channels, prenet_channels),
+            linear(time_embed_dim, time_embed_dim),
             nn.SiLU(),
-            linear(prenet_channels, model_channels),
+            linear(time_embed_dim, time_embed_dim),
         )
 
         self.ar_prior = ar_prior
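With this hunk the time-embed MLP stays at `time_embed_dim` end to end (the old version went `prenet_channels -> model_channels`). A runnable sketch of the path from integer timesteps to `blk_emb`; `timestep_embedding` is re-created here in the style of the usual sinusoidal embedding so the snippet stands alone (treat it as an approximation of the repo's helper):

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(timesteps, dim, max_period=10000):
    # Sinusoidal embedding: half cosine channels, half sine channels.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

time_embed_dim = 256
time_embed = nn.Sequential(               # mirrors the reworked Sequential above
    nn.Linear(time_embed_dim, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
)
blk_emb = time_embed(timestep_embedding(torch.LongTensor([600, 600]), time_embed_dim))
print(blk_emb.shape)                      # torch.Size([2, 256])
```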
@@ -150,7 +152,7 @@ class TransformerDiffusion(nn.Module):
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
         self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
         self.intg = nn.Linear(prenet_channels*2, model_channels)
-        self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
+        self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
 
         self.out = nn.Sequential(
             normalization(model_channels),
@@ -197,7 +199,7 @@ class TransformerDiffusion(nn.Module):
             unused_params.append(self.unconditioned_embedding)
 
         with torch.autocast(x.device.type, enabled=self.enable_fp16):
-            blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
+            blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
             x = self.inp_block(x).permute(0,2,1)
 
             rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
@@ -273,6 +275,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
         blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
         groups = {
+            'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
             'blk1_attention_layers': attn1,
             'blk2_attention_layers': attn2,
             'attention_layers': attn1 + attn2,
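The dict built here (now including a `prenorms` group) feeds per-group gradient-norm logging. A self-contained illustration of that usage pattern (the toy model and group name are mine, not the commit's):

```python
import itertools
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
sum(l(torch.randn(4, 8)).sum() for l in layers).backward()

groups = {'linears': list(itertools.chain.from_iterable(l.parameters() for l in layers))}
for name, params in groups.items():
    # One scalar grad norm per group, the quantity this kind of logging tracks.
    norm = torch.norm(torch.stack([p.grad.norm() for p in params if p.grad is not None]))
    print(name, norm.item())
```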
@@ -291,6 +294,16 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         }
         return groups
 
+    def before_step(self, step):
+        scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
+                                 list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
+        # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitude
+        # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
+        # directly fiddling with the gradients.
+        for p in scaled_grad_parameters:
+            p.grad *= .2
+
+
 
 class TransformerDiffusionWithARPrior(nn.Module):
     def __init__(self, freeze_diff=False, **kwargs):
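`before_step` mutates `.grad` in place, so it only makes sense between `backward()` and `optimizer.step()`. The diff doesn't show the trainer's call site; a minimal sketch of the assumed placement:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters())

def before_step(step):
    # Constant rescale, mirroring the hunk's `p.grad *= .2` on selected layers.
    for p in model.parameters():
        p.grad *= .2

for step in range(3):
    opt.zero_grad()
    model(torch.randn(4, 8)).sum().backward()
    before_step(step)  # gradient surgery while .grad is still live
    opt.step()
```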
@@ -353,8 +366,8 @@ def test_quant_model():
     ts = torch.LongTensor([600, 600])
 
     # For music:
-    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, contraction_dim=512,
-                                              prenet_channels=1024, num_heads=8,
+    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1536, contraction_dim=768,
+                                              prenet_channels=1024, num_heads=10,
                                               input_vec_dim=1024, num_layers=24, prenet_layers=4,
                                               dropout=.1)
     quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
@@ -363,7 +376,18 @@ def test_quant_model():
 
     print_network(model)
     o = model(clip, ts, clip)
-    model.get_grad_norm_parameter_groups()
+    pg = model.get_grad_norm_parameter_groups()
+    t = 0
+    for k, vs in pg.items():
+        s = 0
+        for v in vs:
+            m = 1
+            for d in v.shape:
+                m *= d
+            s += m
+        t += s
+        print(k, s/1000000)
+    print(t)
 
 
 def test_ar_model():
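The nested loops added to `test_quant_model` just multiply out each tensor's shape; `Tensor.numel()` does the same in one call. An equivalent sketch, with a stand-in `pg` so it runs on its own:

```python
import torch

pg = {'example_group': [torch.zeros(768, 1536), torch.zeros(1536)]}  # stand-in groups
total = 0
for name, params in pg.items():
    count = sum(p.numel() for p in params)  # elements per tensor, summed per group
    total += count
    print(name, count / 1e6)                # per-group size in millions of params
print(total)
```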
@@ -352,12 +352,13 @@ class RMSNorm(nn.Module):
 
 
 class RMSScaleShiftNorm(nn.Module):
-    def __init__(self, dim, eps=1e-8, bias=True):
+    def __init__(self, dim, embed_dim=None, eps=1e-8, bias=True):
         super().__init__()
+        embed_dim = default(embed_dim, dim)
         self.scale = dim ** -0.5
         self.eps = eps
         self.g = nn.Parameter(torch.ones(dim))
-        self.scale_shift_process = nn.Linear(dim, dim * 2, bias=bias)
+        self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)
 
     def forward(self, x, norm_scale_shift_inp):
         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
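This hunk cuts off one line into `forward`. For context, a plausible completion consistent with the projection defined above (RMS-normalize, then apply an embedding-conditioned scale and shift); this is a reconstruction, not the committed code, with `default(...)` inlined:

```python
import torch
import torch.nn as nn

class RMSScaleShiftNorm(nn.Module):
    def __init__(self, dim, embed_dim=None, eps=1e-8, bias=True):
        super().__init__()
        embed_dim = embed_dim if embed_dim is not None else dim  # default(embed_dim, dim)
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))
        self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)

    def forward(self, x, norm_scale_shift_inp):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        x = x / norm.clamp(min=self.eps) * self.g
        # Reconstruction: split the projected embedding into scale and shift.
        scale, shift = self.scale_shift_process(norm_scale_shift_inp).chunk(2, dim=-1)
        return x * (1 + scale) + shift

out = RMSScaleShiftNorm(1024, embed_dim=256)(torch.randn(2, 10, 1024), torch.randn(2, 1, 256))
print(out.shape)  # torch.Size([2, 10, 1024])
```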