more reworks

This commit is contained in:
James Betker 2022-06-13 08:40:23 -06:00
parent 7a36668870
commit 1fde3e5a08
2 changed files with 36 additions and 11 deletions

View File

@ -60,9 +60,9 @@ class SubBlock(nn.Module):
class ConcatAttentionBlock(TimestepBlock):
def __init__(self, trunk_dim, contraction_dim, heads, dropout):
def __init__(self, trunk_dim, contraction_dim, time_embed_dim, heads, dropout):
super().__init__()
self.prenorm = RMSScaleShiftNorm(trunk_dim, bias=False)
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
@ -84,6 +84,7 @@ class TransformerDiffusion(nn.Module):
self,
prenet_channels=1024,
prenet_layers=3,
time_embed_dim=256,
model_channels=1024,
contraction_dim=256,
num_layers=8,
@ -103,6 +104,7 @@ class TransformerDiffusion(nn.Module):
self.in_channels = in_channels
self.model_channels = model_channels
self.prenet_channels = prenet_channels
self.time_embed_dim = time_embed_dim
self.out_channels = out_channels
self.dropout = dropout
self.unconditioned_percentage = unconditioned_percentage
@ -111,9 +113,9 @@ class TransformerDiffusion(nn.Module):
self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
self.time_embed = nn.Sequential(
linear(prenet_channels, prenet_channels),
linear(time_embed_dim, time_embed_dim),
nn.SiLU(),
linear(prenet_channels, model_channels),
linear(time_embed_dim, time_embed_dim),
)
self.ar_prior = ar_prior
@ -150,7 +152,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = nn.Linear(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(
normalization(model_channels),
@ -197,7 +199,7 @@ class TransformerDiffusion(nn.Module):
unused_params.append(self.unconditioned_embedding)
with torch.autocast(x.device.type, enabled=self.enable_fp16):
blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
x = self.inp_block(x).permute(0,2,1)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
@ -273,6 +275,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
groups = {
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
'blk1_attention_layers': attn1,
'blk2_attention_layers': attn2,
'attention_layers': attn1 + attn2,
@ -291,6 +294,16 @@ class TransformerDiffusionWithQuantizer(nn.Module):
}
return groups
def before_step(self, step):
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
# directly fiddling with the gradients.
for p in scaled_grad_parameters:
p.grad *= .2
class TransformerDiffusionWithARPrior(nn.Module):
def __init__(self, freeze_diff=False, **kwargs):
@ -353,8 +366,8 @@ def test_quant_model():
ts = torch.LongTensor([600, 600])
# For music:
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, contraction_dim=512,
prenet_channels=1024, num_heads=8,
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1536, contraction_dim=768,
prenet_channels=1024, num_heads=10,
input_vec_dim=1024, num_layers=24, prenet_layers=4,
dropout=.1)
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
@ -363,7 +376,18 @@ def test_quant_model():
print_network(model)
o = model(clip, ts, clip)
model.get_grad_norm_parameter_groups()
pg = model.get_grad_norm_parameter_groups()
t = 0
for k, vs in pg.items():
s = 0
for v in vs:
m = 1
for d in v.shape:
m *= d
s += m
t += s
print(k, s/1000000)
print(t)
def test_ar_model():

View File

@ -352,12 +352,13 @@ class RMSNorm(nn.Module):
class RMSScaleShiftNorm(nn.Module):
def __init__(self, dim, eps=1e-8, bias=True):
def __init__(self, dim, embed_dim=None, eps=1e-8, bias=True):
super().__init__()
embed_dim = default(embed_dim, dim)
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
self.scale_shift_process = nn.Linear(dim, dim * 2, bias=bias)
self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)
def forward(self, x, norm_scale_shift_inp):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale