rework tfd13 further

- use a gated activation layer (GLU-style) for both the attention and conv blocks
- add a relative, learned position bias. I believe this is similar to the T5 relative position encodings, but simpler
- get rid of prepending the conditioning to the attention input - this doesn't really work that well. The model eventually learns to dedicate one of its heads to attending to these blocks, so why not just concatenate the conditioning onto the channels instead? (A short sketch of the difference follows this list.)
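
To illustrate the last point, here is a minimal sketch (not taken from the diff; tensor names and sizes are made up) contrasting prepending conditioning tokens to the attention input with concatenating a broadcast conditioning vector onto the channels:

    import torch

    b, trunk_ch, cond_ch, t = 2, 64, 16, 100
    x = torch.randn(b, trunk_ch, t)        # trunk activations: (batch, channels, time)
    cond = torch.randn(b, cond_ch)         # one conditioning vector per batch element

    # Old scheme: project the conditioning to trunk channels and prepend it as extra
    # timesteps, relying on attention heads to learn to look back at those positions.
    cond_tokens = torch.randn(b, trunk_ch, 1)            # stand-in for a learned projection
    x_prepended = torch.cat([cond_tokens, x], dim=-1)    # (b, trunk_ch, 1 + t)

    # New scheme: broadcast the conditioning along time and concatenate it onto the
    # channel axis, so every position sees it directly.
    cond_bcast = cond.unsqueeze(-1).repeat(1, 1, t)      # (b, cond_ch, t)
    x_concat = torch.cat([x, cond_bcast], dim=1)         # (b, trunk_ch + cond_ch, t)
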
James Betker 2022-07-20 23:28:29 -06:00
parent 40427de8e3
commit ee8ceed6da
3 changed files with 111 additions and 55 deletions


@ -341,6 +341,20 @@ class Downsample(nn.Module):
return self.op(x)
class cGLU(nn.Module):
    """
    Gated GELU for channel-first architectures.
    """
    def __init__(self, dim_in, dim_out=None):
        super().__init__()
        dim_out = dim_in if dim_out is None else dim_out
        self.proj = nn.Conv1d(dim_in, dim_out * 2, 1)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=1)
        return x * F.gelu(gate)
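
A minimal functional sketch of the gated unit above (sizes are made up): the 1x1 conv doubles the channel count, which is then split into a value half and a gate half.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    proj = nn.Conv1d(64, 32 * 2, 1)             # 1x1 conv produces 2x the output channels
    x = torch.randn(2, 64, 100)                 # (batch, channels, time)
    v, gate = proj(x).chunk(2, dim=1)           # split into values and gate
    y = v * F.gelu(gate)                        # (2, 32, 100)
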
class ResBlock(nn.Module):
"""
A residual block that can optionally change the number of channels.
@ -439,7 +453,7 @@ class ResBlock(nn.Module):
return self.skip_connection(x) + h
def build_local_attention_mask(n, l, fixed_region):
def build_local_attention_mask(n, l, fixed_region=0):
"""
Builds an attention mask that focuses attention on local region
Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
@ -465,6 +479,24 @@ def test_local_attention_mask():
print(build_local_attention_mask(9,4,1))
class RelativeQKBias(nn.Module):
    """
    Very simple relative position bias scheme whose output is added directly to the QK matrix.
    The bias depends only on the distance between the two positions.
    """
    def __init__(self, l, max_positions=4000):
        super().__init__()
        self.emb = nn.Parameter(torch.randn(l+1) * .01)
        o = torch.arange(0, max_positions)
        c = o.unsqueeze(-1).repeat(1, max_positions)
        r = o.unsqueeze(0).repeat(max_positions, 1)
        M = ((-(r-c).abs())+l).clamp(0, l)
        self.register_buffer('M', M, persistent=False)

    def forward(self, n):
        return self.emb[self.M[:n, :n]].view(1, n, n)
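
A standalone sketch of what the buffer M above encodes, reproducing the indexing logic with tiny sizes:

    import torch

    l, max_positions = 4, 8
    o = torch.arange(0, max_positions)
    c = o.unsqueeze(-1).repeat(1, max_positions)
    r = o.unsqueeze(0).repeat(max_positions, 1)
    M = ((-(r - c).abs()) + l).clamp(0, l)
    # M[i, j] = max(l - |i - j|, 0): positions close together index the higher
    # entries of the learned (l+1)-element bias table, and anything l or more
    # steps apart shares index 0. forward() then looks up emb[M[:n, :n]] to get
    # a (1, n, n) bias that is added to the pre-softmax QK matrix.
    print(M)
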
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
@ -507,16 +539,22 @@ class AttentionBlock(nn.Module):
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
def forward(self, x, mask=None):
def forward(self, x, mask=None, qk_bias=None):
if self.do_checkpoint:
if mask is not None:
return checkpoint(self._forward, x, mask)
if mask is None:
if qk_bias is None:
return checkpoint(self._forward, x)
else:
assert False, 'unsupported: qk_bias but no mask'
else:
return checkpoint(self._forward, x)
if qk_bias is None:
return checkpoint(self._forward, x, mask)
else:
return checkpoint(self._forward, x, mask, qk_bias)
else:
return self._forward(x, mask)
def _forward(self, x, mask=None):
def _forward(self, x, mask=None, qk_bias=0):
b, c, *spatial = x.shape
if mask is not None:
if len(mask.shape) == 2:
@ -529,7 +567,7 @@ class AttentionBlock(nn.Module):
if self.do_activation:
x = F.silu(x, inplace=True)
qkv = self.qkv(x)
h = self.attention(qkv, mask)
h = self.attention(qkv, mask, qk_bias)
h = self.proj_out(h)
xp = self.x_proj(x)
return (xp + h).reshape(b, xp.shape[1], *spatial)
@ -544,7 +582,7 @@ class QKVAttentionLegacy(nn.Module):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, mask=None):
def forward(self, qkv, mask=None, qk_bias=0):
"""
Apply QKV attention.
@ -559,6 +597,7 @@ class QKVAttentionLegacy(nn.Module):
weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = weight + qk_bias
if mask is not None:
mask = mask.repeat(self.n_heads, 1, 1)
weight[mask.logical_not()] = -torch.inf
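
A sketch (sizes made up) of how the added qk_bias combines with the attention weights computed above: weight has shape (batch*heads, n, n) and the bias from RelativeQKBias.forward has shape (1, n, n), so plain broadcasting applies it identically to every head.

    import torch

    bh, ch, n = 8, 32, 16
    q = torch.randn(bh, ch, n)
    k = torch.randn(bh, ch, n)
    scale = ch ** -0.25                          # assumed: split 1/sqrt(ch) between q and k
    weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)   # (bh, n, n)
    qk_bias = torch.randn(1, n, n)               # e.g. the output of RelativeQKBias(l)(n)
    weight = weight + qk_bias                    # broadcasts over the bh dimension
    attn = torch.softmax(weight, dim=-1)
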
@ -577,7 +616,7 @@ class QKVAttention(nn.Module):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, mask=None):
def forward(self, qkv, mask=None, qk_bias=0):
"""
Apply QKV attention.


@ -6,7 +6,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
RelativeQKBias
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
from trainer.networks import register_model
@ -22,73 +23,73 @@ def is_sequence(t):
class SubBlock(nn.Module):
def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout):
def __init__(self, inp_dim, contraction_dim, heads, dropout):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
self.blk_emb_proj = nn.Conv1d(blk_dim, inp_dim, 1)
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
self.pos_bias = RelativeQKBias(l=64)
self.attn_glu = cGLU(contraction_dim)
self.attnorm = nn.GroupNorm(8, contraction_dim)
self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
self.ff_glu = cGLU(contraction_dim)
self.ffnorm = nn.GroupNorm(8, contraction_dim)
self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
self.mask_initialized = False
def forward(self, x, blk_emb):
if self.mask is not None and not self.mask_initialized:
self.mask = self.mask.to(x.device)
self.mask_initialized = True
blk_enc = self.blk_emb_proj(blk_emb)
ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
ah = ah[:,:,blk_enc.shape[-1]:] # Strip off the blk_emb used for attention and re-align with x.
ah = F.gelu(self.attnorm(ah))
def forward(self, x):
ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
ah = self.attn_glu(self.attnorm(ah))
h = torch.cat([ah, x], dim=1)
hf = self.dropout(checkpoint(self.ff, h))
hf = F.gelu(self.ffnorm(hf))
hf = self.ff_glu(self.ffnorm(hf))
h = torch.cat([h, hf], dim=1)
return h
class ConcatAttentionBlock(TimestepBlock):
def __init__(self, trunk_dim, contraction_dim, heads, dropout):
def __init__(self, trunk_dim, contraction_dim, blk_dim, heads, dropout):
super().__init__()
self.contraction_dim = contraction_dim
self.prenorm = nn.GroupNorm(8, trunk_dim)
self.block1 = SubBlock(trunk_dim, contraction_dim, trunk_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, trunk_dim, heads, dropout)
self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
self.out.weight.data.zero_()
def forward(self, x, blk_emb):
h = self.prenorm(x)
h = self.block1(h, blk_emb)
h = self.block2(h, blk_emb)
h = self.out(h[:,x.shape[1]:])
h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
h = self.block1(h)
h = self.block2(h)
h = self.out(h[:,-self.contraction_dim*4:])
return h + x
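
To make the channel bookkeeping above easier to follow, a rough accounting inferred from the diff (dims are named after the constructor arguments; the numeric values are placeholders only):

    trunk_dim, blk_dim, contraction_dim = 1024, 448, 128

    inp1 = trunk_dim + blk_dim                 # x with blk_emb broadcast onto the channel axis
    out1 = inp1 + 2 * contraction_dim          # each SubBlock concatenates its attn and ff outputs
    inp2 = out1                                # hence block2's inp_dim in the constructor
    out2 = inp2 + 2 * contraction_dim
    assert out2 == trunk_dim + blk_dim + 4 * contraction_dim
    # self.out is a 1x1 conv that maps the last contraction_dim*4 of these channels
    # back to trunk_dim, and the result is added to x as a residual.
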
class ConditioningEncoder(nn.Module):
def __init__(self,
spec_dim,
embedding_dim,
hidden_dim,
out_dim,
num_resolutions,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=5, stride=2)
self.resolution_embedding = nn.Embedding(num_resolutions, embedding_dim)
self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
self.resolution_embedding = nn.Embedding(num_resolutions, hidden_dim)
self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(ResBlock(embedding_dim, dims=1, checkpointing_enabled=do_checkpointing))
attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.out = nn.Linear(hidden_dim, out_dim, bias=False)
self.dim = hidden_dim
self.do_checkpointing = do_checkpointing
def forward(self, x, resolution):
h = self.init(x) + self.resolution_embedding(resolution).unsqueeze(-1)
h = self.attn(h)
return h[:, :, :5]
return self.out(h[:, :, 0])
class TransformerDiffusion(nn.Module):
@ -97,7 +98,6 @@ class TransformerDiffusion(nn.Module):
"""
def __init__(
self,
time_embed_dim=256,
resolution_steps=8,
max_window=384,
model_channels=1024,
@ -106,6 +106,9 @@ class TransformerDiffusion(nn.Module):
in_channels=256,
input_vec_dim=1024,
out_channels=512, # mean and variance
time_embed_dim=256,
time_proj_dim=64,
cond_proj_dim=256,
num_heads=4,
dropout=0,
use_fp16=False,
@ -128,19 +131,20 @@ class TransformerDiffusion(nn.Module):
self.time_embed = nn.Sequential(
linear(time_embed_dim, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, model_channels),
linear(time_embed_dim, time_proj_dim),
)
self.prior_time_embed = nn.Sequential(
linear(time_embed_dim, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, model_channels),
linear(time_embed_dim, time_proj_dim),
)
self.resolution_embed = nn.Embedding(resolution_steps, model_channels)
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, resolution_steps, num_attn_heads=model_channels//64)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,5))
self.resolution_embed = nn.Embedding(resolution_steps, time_proj_dim)
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim,
num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(
normalization(model_channels),
@ -246,15 +250,14 @@ class TransformerDiffusion(nn.Module):
# Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((x.shape[0], 1, 1),
device=x.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1), code_emb)
unconditioned_batches = torch.rand((x.shape[0], 1), device=x.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1), code_emb)
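
A sketch of the conditioning dropout above (sizes made up): a (b, 1) boolean mask broadcasts across the embedding dimension in torch.where, replacing the conditioning of whole batch elements with the learned unconditioned embedding.

    import torch

    b, cond_dim, p = 4, 256, 0.1
    code_emb = torch.randn(b, cond_dim)
    uncond = torch.randn(1, cond_dim)            # stands in for self.unconditioned_embedding
    drop = torch.rand(b, 1) < p
    code_emb = torch.where(drop, uncond.repeat(b, 1), code_emb)
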
with torch.autocast(x.device.type, enabled=self.enable_fp16):
time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim))
res_emb = self.resolution_embed(resolution)
blk_emb = torch.cat([time_emb.unsqueeze(-1), prior_time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
blk_emb = torch.cat([time_emb, prior_time_emb, res_emb, code_emb], dim=1)
h = torch.cat([x, x_prior], dim=1)
h = self.inp_block(h)
@ -304,5 +307,5 @@ def remove_conditioning(sd_path):
if __name__ == '__main__':
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
#remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
test_tfd()


@ -146,7 +146,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
# x = x.clamp(-s, s) / s
# return x
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm})
gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm}, eta=.8)
gen_mel_denorm = denormalize_torch_mel(gen_mel)
output_shape = (1,16,audio.shape[-1]//16)
@ -230,7 +230,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_torch_mel(mel)
#mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
conditioning = mel_norm[:,:,:1200]
downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
stage1_shape = (1, 256, downsampled.shape[-1]*4)
@ -323,19 +322,34 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
"""
# For multilevel SR:
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
also_load_savepoint=False, strict_load=False,
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\22000_generator.pth'
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\4000_generator.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 128, # basis: 192
'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr',
#'causal': True, 'causal_slope': 4,
#'partial_low': 128, 'partial_high': 192
'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr',
}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
"""
# For TFD+cheater trainer
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
also_load_savepoint=False, strict_load=False,
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\20000_generator.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 128, # basis: 192
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 6, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
fds = []
for i in range(2):