forked from mrq/DL-Art-School

commit 1fde3e5a08 (parent 7a36668870)

    more reworks
@@ -60,9 +60,9 @@ class SubBlock(nn.Module):
 class ConcatAttentionBlock(TimestepBlock):
-    def __init__(self, trunk_dim, contraction_dim, heads, dropout):
+    def __init__(self, trunk_dim, contraction_dim, time_embed_dim, heads, dropout):
         super().__init__()
-        self.prenorm = RMSScaleShiftNorm(trunk_dim, bias=False)
+        self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
         self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
         self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
         self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
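The visible dimensions imply the block's concat pattern: each SubBlock appends 2*contraction_dim features to the trunk, so after two blocks self.out projects the accumulated 4*contraction_dim features back down to trunk_dim. A minimal sketch of that shape arithmetic (the residual add and the stand-in feature tensors are assumptions, not shown in this hunk):

    import torch

    trunk_dim, contraction_dim = 1024, 256
    x = torch.randn(2, 50, trunk_dim)               # trunk activations
    b1 = torch.randn(2, 50, contraction_dim * 2)    # stand-in for block1's appended features
    b2 = torch.randn(2, 50, contraction_dim * 2)    # stand-in for block2's appended features
    out = torch.nn.Linear(contraction_dim * 4, trunk_dim, bias=False)
    y = x + out(torch.cat([b1, b2], dim=-1))        # residual back onto the trunk (assumed)
    print(y.shape)                                  # torch.Size([2, 50, 1024])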
@@ -84,6 +84,7 @@ class TransformerDiffusion(nn.Module):
             self,
             prenet_channels=1024,
             prenet_layers=3,
+            time_embed_dim=256,
             model_channels=1024,
             contraction_dim=256,
             num_layers=8,
@@ -103,6 +104,7 @@ class TransformerDiffusion(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.prenet_channels = prenet_channels
+        self.time_embed_dim = time_embed_dim
         self.out_channels = out_channels
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
@@ -111,9 +113,9 @@ class TransformerDiffusion(nn.Module):
         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)

         self.time_embed = nn.Sequential(
-            linear(prenet_channels, prenet_channels),
+            linear(time_embed_dim, time_embed_dim),
             nn.SiLU(),
-            linear(prenet_channels, model_channels),
+            linear(time_embed_dim, time_embed_dim),
         )

         self.ar_prior = ar_prior
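With this change the timestep MLP lives entirely at time_embed_dim instead of mapping prenet_channels to model_channels; the per-block prenorm consumes the resulting embedding directly (see the RMSScaleShiftNorm hunk below). A self-contained sketch of the pipeline, where the sinusoidal timestep_embedding helper is an assumption modeled on guided-diffusion-style codebases:

    import math
    import torch
    import torch.nn as nn

    def timestep_embedding(timesteps, dim, max_period=10000):
        # Sinusoidal embedding in the guided-diffusion style; a sketch of what
        # the repo's helper likely computes.
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
        args = timesteps[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    time_embed_dim = 256  # the new default introduced by this commit
    time_embed = nn.Sequential(
        nn.Linear(time_embed_dim, time_embed_dim),  # linear() in the diff wraps nn.Linear
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim),
    )
    blk_emb = time_embed(timestep_embedding(torch.LongTensor([600, 600]), time_embed_dim))
    print(blk_emb.shape)  # torch.Size([2, 256])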
@@ -150,7 +152,7 @@ class TransformerDiffusion(nn.Module):
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
         self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
         self.intg = nn.Linear(prenet_channels*2, model_channels)
-        self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
+        self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])

         self.out = nn.Sequential(
             normalization(model_channels),
@@ -197,7 +199,7 @@ class TransformerDiffusion(nn.Module):
         unused_params.append(self.unconditioned_embedding)

         with torch.autocast(x.device.type, enabled=self.enable_fp16):
-            blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
+            blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
             x = self.inp_block(x).permute(0,2,1)

             rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
@@ -273,6 +275,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
         blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
         groups = {
+            'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
             'blk1_attention_layers': attn1,
             'blk2_attention_layers': attn2,
             'attention_layers': attn1 + attn2,
@@ -291,6 +294,16 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         }
         return groups

+    def before_step(self, step):
+        scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
+                                 list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
+        # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get gradients two
+        # orders of magnitude higher than the rest of the model. Ideally we would use parameter groups, but
+        # ZeroRedundancyOptimizer makes this trickier than directly fiddling with the gradients.
+        for p in scaled_grad_parameters:
+            p.grad *= .2
+
+

 class TransformerDiffusionWithARPrior(nn.Module):
     def __init__(self, freeze_diff=False, **kwargs):
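The new before_step hook shrinks the gradients of the per-layer out and prenorm parameters by a factor of five before the optimizer consumes them, which pairs with the 'prenorms' group added to get_grad_norm_parameter_groups above for monitoring. A minimal sketch of the intended call order; the toy model and the trainer wiring are assumptions, not part of the commit:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)  # stand-in for TransformerDiffusionWithQuantizer
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

    def before_step(parameters, factor=0.2):
        # Mirrors the commit's hook: rescale selected gradients directly instead of
        # using optimizer parameter groups (awkward under ZeroRedundancyOptimizer).
        for p in parameters:
            if p.grad is not None:
                p.grad *= factor

    loss = model(torch.randn(4, 8)).square().mean()
    loss.backward()
    before_step(model.parameters())  # in the repo: model.before_step(step)
    opt.step()
    opt.zero_grad()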
@@ -353,8 +366,8 @@ def test_quant_model():
     ts = torch.LongTensor([600, 600])

     # For music:
-    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=2048, contraction_dim=512,
-                                              prenet_channels=1024, num_heads=8,
+    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=1536, contraction_dim=768,
+                                              prenet_channels=1024, num_heads=10,
                                               input_vec_dim=1024, num_layers=24, prenet_layers=4,
                                               dropout=.1)
     quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
@@ -363,7 +376,18 @@ def test_quant_model():

     print_network(model)
     o = model(clip, ts, clip)
-    model.get_grad_norm_parameter_groups()
+    pg = model.get_grad_norm_parameter_groups()
+    t = 0
+    for k, vs in pg.items():
+        s = 0
+        for v in vs:
+            m = 1
+            for d in v.shape:
+                m *= d
+            s += m
+        t += s
+        print(k, s/1000000)
+    print(t)


 def test_ar_model():
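The loop appended to test_quant_model() is a hand-rolled per-group parameter count, printed in millions. It is equivalent to summing Tensor.numel() over each group; a compact sketch on a toy group dict (the group names and shapes here are illustrative):

    import torch

    pg = {'attention_layers': [torch.randn(768, 768), torch.randn(768)],
          'out': [torch.randn(1536, 1024)]}
    total = 0
    for name, params in pg.items():
        count = sum(p.numel() for p in params)
        total += count
        print(name, count / 1e6)   # per-group parameter count, in millions
    print(total)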
@@ -352,12 +352,13 @@ class RMSNorm(nn.Module):


 class RMSScaleShiftNorm(nn.Module):
-    def __init__(self, dim, eps=1e-8, bias=True):
+    def __init__(self, dim, embed_dim=None, eps=1e-8, bias=True):
         super().__init__()
+        embed_dim = default(embed_dim, dim)
         self.scale = dim ** -0.5
         self.eps = eps
         self.g = nn.Parameter(torch.ones(dim))
-        self.scale_shift_process = nn.Linear(dim, dim * 2, bias=bias)
+        self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)

     def forward(self, x, norm_scale_shift_inp):
         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
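Decoupling embed_dim from dim is what lets ConcatAttentionBlock condition a wide trunk on the narrower 256-wide time embedding. A self-contained sketch of the whole module; everything past the last visible forward line is a hedged reconstruction in the usual scale-shift style, not lifted from the repo:

    import torch
    import torch.nn as nn

    class RMSScaleShiftNorm(nn.Module):
        def __init__(self, dim, embed_dim=None, eps=1e-8, bias=True):
            super().__init__()
            embed_dim = embed_dim if embed_dim is not None else dim  # default(...) in the repo
            self.scale = dim ** -0.5
            self.eps = eps
            self.g = nn.Parameter(torch.ones(dim))
            self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)

        def forward(self, x, norm_scale_shift_inp):
            norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
            h = x / norm.clamp(min=self.eps) * self.g   # RMS-normalize (assumed continuation)
            ss = self.scale_shift_process(norm_scale_shift_inp)
            scale, shift = ss.chunk(2, dim=-1)
            return h * (1 + scale) + shift              # FiLM-style modulation (assumed)

    # A 256-wide time embedding conditioning a 1024-wide trunk; the conditioning input
    # carries a sequence axis of 1 so it broadcasts over the trunk's sequence dimension.
    norm = RMSScaleShiftNorm(1024, embed_dim=256, bias=False)
    y = norm(torch.randn(2, 50, 1024), torch.randn(2, 1, 256))
    print(y.shape)  # torch.Size([2, 50, 1024])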