forked from mrq/DL-Art-School
rework tfd13 further
- use a gated activation layer for both attention & convs - add a relativistic learned position bias. I believe this is similar to the T5 position encodings but it is simpler and learned - get rid of prepending to the attention matrix - this doesn't really work that well. the model eventually learns to attend one of its heads to these blocks but why not just concat if it is doing that?
This commit is contained in:
parent
40427de8e3
commit
ee8ceed6da
|
@ -341,6 +341,20 @@ class Downsample(nn.Module):
|
|||
return self.op(x)
|
||||
|
||||
|
||||
class cGLU(nn.Module):
|
||||
"""
|
||||
Gated GELU for channel-first architectures.
|
||||
"""
|
||||
def __init__(self, dim_in, dim_out=None):
|
||||
super().__init__()
|
||||
dim_out = dim_in if dim_out is None else dim_out
|
||||
self.proj = nn.Conv1d(dim_in, dim_out * 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
|
@ -439,7 +453,7 @@ class ResBlock(nn.Module):
|
|||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
def build_local_attention_mask(n, l, fixed_region):
|
||||
def build_local_attention_mask(n, l, fixed_region=0):
|
||||
"""
|
||||
Builds an attention mask that focuses attention on local region
|
||||
Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
|
||||
|
@ -465,6 +479,24 @@ def test_local_attention_mask():
|
|||
print(build_local_attention_mask(9,4,1))
|
||||
|
||||
|
||||
class RelativeQKBias(nn.Module):
|
||||
"""
|
||||
Very simple relative position bias scheme which should be directly added to QK matrix. This bias simply applies to
|
||||
the distance from the given element.
|
||||
"""
|
||||
def __init__(self, l, max_positions=4000):
|
||||
super().__init__()
|
||||
self.emb = nn.Parameter(torch.randn(l+1) * .01)
|
||||
o = torch.arange(0,max_positions)
|
||||
c = o.unsqueeze(-1).repeat(1,max_positions)
|
||||
r = o.unsqueeze(0).repeat(max_positions,1)
|
||||
M = ((-(r-c).abs())+l).clamp(0,l)
|
||||
self.register_buffer('M', M, persistent=False)
|
||||
|
||||
def forward(self, n):
|
||||
return self.emb[self.M[:n, :n]].view(1,n,n)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
|
@ -507,16 +539,22 @@ class AttentionBlock(nn.Module):
|
|||
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
|
||||
self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
def forward(self, x, mask=None, qk_bias=None):
|
||||
if self.do_checkpoint:
|
||||
if mask is not None:
|
||||
return checkpoint(self._forward, x, mask)
|
||||
if mask is None:
|
||||
if qk_bias is None:
|
||||
return checkpoint(self._forward, x)
|
||||
else:
|
||||
assert False, 'unsupported: qk_bias but no mask'
|
||||
else:
|
||||
return checkpoint(self._forward, x)
|
||||
if qk_bias is None:
|
||||
return checkpoint(self._forward, x, mask)
|
||||
else:
|
||||
return checkpoint(self._forward, x, mask, qk_bias)
|
||||
else:
|
||||
return self._forward(x, mask)
|
||||
|
||||
def _forward(self, x, mask=None):
|
||||
def _forward(self, x, mask=None, qk_bias=0):
|
||||
b, c, *spatial = x.shape
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
|
@ -529,7 +567,7 @@ class AttentionBlock(nn.Module):
|
|||
if self.do_activation:
|
||||
x = F.silu(x, inplace=True)
|
||||
qkv = self.qkv(x)
|
||||
h = self.attention(qkv, mask)
|
||||
h = self.attention(qkv, mask, qk_bias)
|
||||
h = self.proj_out(h)
|
||||
xp = self.x_proj(x)
|
||||
return (xp + h).reshape(b, xp.shape[1], *spatial)
|
||||
|
@ -544,7 +582,7 @@ class QKVAttentionLegacy(nn.Module):
|
|||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv, mask=None):
|
||||
def forward(self, qkv, mask=None, qk_bias=0):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
|
@ -559,6 +597,7 @@ class QKVAttentionLegacy(nn.Module):
|
|||
weight = torch.einsum(
|
||||
"bct,bcs->bts", q * scale, k * scale
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = weight + qk_bias
|
||||
if mask is not None:
|
||||
mask = mask.repeat(self.n_heads, 1, 1)
|
||||
weight[mask.logical_not()] = -torch.inf
|
||||
|
@ -577,7 +616,7 @@ class QKVAttention(nn.Module):
|
|||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv, mask=None):
|
||||
def forward(self, qkv, mask=None, qk_bias=0):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
|
|
|
@ -6,7 +6,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
|
||||
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
|
||||
RelativeQKBias
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
from trainer.networks import register_model
|
||||
|
@ -22,73 +23,73 @@ def is_sequence(t):
|
|||
|
||||
|
||||
class SubBlock(nn.Module):
|
||||
def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout):
|
||||
def __init__(self, inp_dim, contraction_dim, heads, dropout):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.blk_emb_proj = nn.Conv1d(blk_dim, inp_dim, 1)
|
||||
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
|
||||
self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
|
||||
self.pos_bias = RelativeQKBias(l=64)
|
||||
self.attn_glu = cGLU(contraction_dim)
|
||||
self.attnorm = nn.GroupNorm(8, contraction_dim)
|
||||
self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
|
||||
self.ff_glu = cGLU(contraction_dim)
|
||||
self.ffnorm = nn.GroupNorm(8, contraction_dim)
|
||||
self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
|
||||
self.mask_initialized = False
|
||||
|
||||
def forward(self, x, blk_emb):
|
||||
if self.mask is not None and not self.mask_initialized:
|
||||
self.mask = self.mask.to(x.device)
|
||||
self.mask_initialized = True
|
||||
blk_enc = self.blk_emb_proj(blk_emb)
|
||||
ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
|
||||
ah = ah[:,:,blk_enc.shape[-1]:] # Strip off the blk_emc used for attention and re-align with x.
|
||||
ah = F.gelu(self.attnorm(ah))
|
||||
def forward(self, x):
|
||||
ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
|
||||
ah = self.attn_glu(self.attnorm(ah))
|
||||
h = torch.cat([ah, x], dim=1)
|
||||
hf = self.dropout(checkpoint(self.ff, h))
|
||||
hf = F.gelu(self.ffnorm(hf))
|
||||
hf = self.ff_glu(self.ffnorm(hf))
|
||||
h = torch.cat([h, hf], dim=1)
|
||||
return h
|
||||
|
||||
|
||||
class ConcatAttentionBlock(TimestepBlock):
|
||||
def __init__(self, trunk_dim, contraction_dim, heads, dropout):
|
||||
def __init__(self, trunk_dim, contraction_dim, blk_dim, heads, dropout):
|
||||
super().__init__()
|
||||
self.contraction_dim = contraction_dim
|
||||
self.prenorm = nn.GroupNorm(8, trunk_dim)
|
||||
self.block1 = SubBlock(trunk_dim, contraction_dim, trunk_dim, heads, dropout)
|
||||
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, trunk_dim, heads, dropout)
|
||||
self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout)
|
||||
self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
|
||||
self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
|
||||
self.out.weight.data.zero_()
|
||||
|
||||
def forward(self, x, blk_emb):
|
||||
h = self.prenorm(x)
|
||||
h = self.block1(h, blk_emb)
|
||||
h = self.block2(h, blk_emb)
|
||||
h = self.out(h[:,x.shape[1]:])
|
||||
h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
|
||||
h = self.block1(h)
|
||||
h = self.block2(h)
|
||||
h = self.out(h[:,-self.contraction_dim*4:])
|
||||
return h + x
|
||||
|
||||
|
||||
class ConditioningEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
spec_dim,
|
||||
embedding_dim,
|
||||
hidden_dim,
|
||||
out_dim,
|
||||
num_resolutions,
|
||||
attn_blocks=6,
|
||||
num_attn_heads=4,
|
||||
do_checkpointing=False):
|
||||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=5, stride=2)
|
||||
self.resolution_embedding = nn.Embedding(num_resolutions, embedding_dim)
|
||||
self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
|
||||
self.resolution_embedding = nn.Embedding(num_resolutions, hidden_dim)
|
||||
self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(ResBlock(embedding_dim, dims=1, checkpointing_enabled=do_checkpointing))
|
||||
attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
self.out = nn.Linear(hidden_dim, out_dim, bias=False)
|
||||
self.dim = hidden_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
def forward(self, x, resolution):
|
||||
h = self.init(x) + self.resolution_embedding(resolution).unsqueeze(-1)
|
||||
h = self.attn(h)
|
||||
return h[:, :, :5]
|
||||
return self.out(h[:, :, 0])
|
||||
|
||||
|
||||
class TransformerDiffusion(nn.Module):
|
||||
|
@ -97,7 +98,6 @@ class TransformerDiffusion(nn.Module):
|
|||
"""
|
||||
def __init__(
|
||||
self,
|
||||
time_embed_dim=256,
|
||||
resolution_steps=8,
|
||||
max_window=384,
|
||||
model_channels=1024,
|
||||
|
@ -106,6 +106,9 @@ class TransformerDiffusion(nn.Module):
|
|||
in_channels=256,
|
||||
input_vec_dim=1024,
|
||||
out_channels=512, # mean and variance
|
||||
time_embed_dim=256,
|
||||
time_proj_dim=64,
|
||||
cond_proj_dim=256,
|
||||
num_heads=4,
|
||||
dropout=0,
|
||||
use_fp16=False,
|
||||
|
@ -128,19 +131,20 @@ class TransformerDiffusion(nn.Module):
|
|||
self.time_embed = nn.Sequential(
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, model_channels),
|
||||
linear(time_embed_dim, time_proj_dim),
|
||||
)
|
||||
self.prior_time_embed = nn.Sequential(
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, model_channels),
|
||||
linear(time_embed_dim, time_proj_dim),
|
||||
)
|
||||
self.resolution_embed = nn.Embedding(resolution_steps, model_channels)
|
||||
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, resolution_steps, num_attn_heads=model_channels//64)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,5))
|
||||
self.resolution_embed = nn.Embedding(resolution_steps, time_proj_dim)
|
||||
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
|
||||
|
||||
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
|
||||
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
|
||||
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim,
|
||||
num_heads, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
|
@ -246,15 +250,14 @@ class TransformerDiffusion(nn.Module):
|
|||
|
||||
# Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((x.shape[0], 1, 1),
|
||||
device=x.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1), code_emb)
|
||||
unconditioned_batches = torch.rand((x.shape[0], 1), device=x.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1), code_emb)
|
||||
|
||||
with torch.autocast(x.device.type, enabled=self.enable_fp16):
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
|
||||
prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim))
|
||||
res_emb = self.resolution_embed(resolution)
|
||||
blk_emb = torch.cat([time_emb.unsqueeze(-1), prior_time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
|
||||
blk_emb = torch.cat([time_emb, prior_time_emb, res_emb, code_emb], dim=1)
|
||||
|
||||
h = torch.cat([x, x_prior], dim=1)
|
||||
h = self.inp_block(h)
|
||||
|
@ -304,5 +307,5 @@ def remove_conditioning(sd_path):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
|
||||
#remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
|
||||
test_tfd()
|
||||
|
|
|
@ -146,7 +146,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
# x = x.clamp(-s, s) / s
|
||||
# return x
|
||||
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
|
||||
gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm})
|
||||
gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm}, eta=.8)
|
||||
|
||||
gen_mel_denorm = denormalize_torch_mel(gen_mel)
|
||||
output_shape = (1,16,audio.shape[-1]//16)
|
||||
|
@ -230,7 +230,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
audio = audio.unsqueeze(0)
|
||||
mel = self.spec_fn({'in': audio})['out']
|
||||
mel_norm = normalize_torch_mel(mel)
|
||||
#mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
|
||||
conditioning = mel_norm[:,:,:1200]
|
||||
downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
|
||||
stage1_shape = (1, 256, downsampled.shape[-1]*4)
|
||||
|
@ -323,19 +322,34 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
# For multilevel SR:
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
|
||||
also_load_savepoint=False, strict_load=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\22000_generator.pth'
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\4000_generator.pth'
|
||||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
||||
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
||||
'diffusion_steps': 128, # basis: 192
|
||||
'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr',
|
||||
#'causal': True, 'causal_slope': 4,
|
||||
#'partial_low': 128, 'partial_high': 192
|
||||
'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr',
|
||||
}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
|
||||
"""
|
||||
|
||||
# For TFD+cheater trainer
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
|
||||
also_load_savepoint=False, strict_load=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\20000_generator.pth'
|
||||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
||||
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
||||
'diffusion_steps': 128, # basis: 192
|
||||
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
|
||||
}
|
||||
|
||||
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 6, 'device': 'cuda', 'opt': {}}
|
||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||
fds = []
|
||||
for i in range(2):
|
||||
|
|
Loading…
Reference in New Issue
Block a user