more reworks

This commit is contained in:
James Betker 2022-06-13 08:40:23 -06:00
parent 7a36668870
commit 1fde3e5a08
2 changed files with 36 additions and 11 deletions

View File

@ -60,9 +60,9 @@ class SubBlock(nn.Module):
class ConcatAttentionBlock(TimestepBlock): class ConcatAttentionBlock(TimestepBlock):
def __init__(self, trunk_dim, contraction_dim, heads, dropout): def __init__(self, trunk_dim, contraction_dim, time_embed_dim, heads, dropout):
super().__init__() super().__init__()
self.prenorm = RMSScaleShiftNorm(trunk_dim, bias=False) self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout) self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout) self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False) self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
@ -84,6 +84,7 @@ class TransformerDiffusion(nn.Module):
self, self,
prenet_channels=1024, prenet_channels=1024,
prenet_layers=3, prenet_layers=3,
time_embed_dim=256,
model_channels=1024, model_channels=1024,
contraction_dim=256, contraction_dim=256,
num_layers=8, num_layers=8,
@ -103,6 +104,7 @@ class TransformerDiffusion(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.model_channels = model_channels self.model_channels = model_channels
self.prenet_channels = prenet_channels self.prenet_channels = prenet_channels
self.time_embed_dim = time_embed_dim
self.out_channels = out_channels self.out_channels = out_channels
self.dropout = dropout self.dropout = dropout
self.unconditioned_percentage = unconditioned_percentage self.unconditioned_percentage = unconditioned_percentage
@ -111,9 +113,9 @@ class TransformerDiffusion(nn.Module):
self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
linear(prenet_channels, prenet_channels), linear(time_embed_dim, time_embed_dim),
nn.SiLU(), nn.SiLU(),
linear(prenet_channels, model_channels), linear(time_embed_dim, time_embed_dim),
) )
self.ar_prior = ar_prior self.ar_prior = ar_prior
@ -150,7 +152,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = nn.Linear(prenet_channels*2, model_channels) self.intg = nn.Linear(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)]) self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(model_channels), normalization(model_channels),
@ -197,7 +199,7 @@ class TransformerDiffusion(nn.Module):
unused_params.append(self.unconditioned_embedding) unused_params.append(self.unconditioned_embedding)
with torch.autocast(x.device.type, enabled=self.enable_fp16): with torch.autocast(x.device.type, enabled=self.enable_fp16):
blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels)) blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
x = self.inp_block(x).permute(0,2,1) x = self.inp_block(x).permute(0,2,1)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
@ -273,6 +275,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers])) ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
groups = { groups = {
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
'blk1_attention_layers': attn1, 'blk1_attention_layers': attn1,
'blk2_attention_layers': attn2, 'blk2_attention_layers': attn2,
'attention_layers': attn1 + attn2, 'attention_layers': attn1 + attn2,
@ -291,6 +294,16 @@ class TransformerDiffusionWithQuantizer(nn.Module):
} }
return groups return groups
def before_step(self, step):
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
# directly fiddling with the gradients.
for p in scaled_grad_parameters:
p.grad *= .2
class TransformerDiffusionWithARPrior(nn.Module): class TransformerDiffusionWithARPrior(nn.Module):
def __init__(self, freeze_diff=False, **kwargs): def __init__(self, freeze_diff=False, **kwargs):
@ -353,8 +366,8 @@ def test_quant_model():
ts = torch.LongTensor([600, 600]) ts = torch.LongTensor([600, 600])
# For music: # For music:
model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, contraction_dim=512, model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1536, contraction_dim=768,
prenet_channels=1024, num_heads=8, prenet_channels=1024, num_heads=10,
input_vec_dim=1024, num_layers=24, prenet_layers=4, input_vec_dim=1024, num_layers=24, prenet_layers=4,
dropout=.1) dropout=.1)
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth') quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
@ -363,7 +376,18 @@ def test_quant_model():
print_network(model) print_network(model)
o = model(clip, ts, clip) o = model(clip, ts, clip)
model.get_grad_norm_parameter_groups() pg = model.get_grad_norm_parameter_groups()
t = 0
for k, vs in pg.items():
s = 0
for v in vs:
m = 1
for d in v.shape:
m *= d
s += m
t += s
print(k, s/1000000)
print(t)
def test_ar_model(): def test_ar_model():

View File

@ -352,12 +352,13 @@ class RMSNorm(nn.Module):
class RMSScaleShiftNorm(nn.Module): class RMSScaleShiftNorm(nn.Module):
def __init__(self, dim, eps=1e-8, bias=True): def __init__(self, dim, embed_dim=None, eps=1e-8, bias=True):
super().__init__() super().__init__()
embed_dim = default(embed_dim, dim)
self.scale = dim ** -0.5 self.scale = dim ** -0.5
self.eps = eps self.eps = eps
self.g = nn.Parameter(torch.ones(dim)) self.g = nn.Parameter(torch.ones(dim))
self.scale_shift_process = nn.Linear(dim, dim * 2, bias=bias) self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)
def forward(self, x, norm_scale_shift_inp): def forward(self, x, norm_scale_shift_inp):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale norm = torch.norm(x, dim=-1, keepdim=True) * self.scale