rework tfd13 further

- use a gated activation layer (cGLU) for both attention & convs
- add a relative learned position bias. I believe this is similar to the T5 relative position bias, but simpler (a short usage sketch follows the commit metadata below)
- get rid of prepending the conditioning blocks to the attention sequence - this doesn't really work that well. The model eventually learns to dedicate one of its heads to attending to these blocks, so why not just concatenate the conditioning along the channel dimension if that is all it is doing?
James Betker 2022-07-20 23:28:29 -06:00
parent 40427de8e3
commit ee8ceed6da
3 changed files with 111 additions and 55 deletions
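For readers skimming the diff, here is a minimal sketch of how the two new pieces are used. This is my own illustration with toy sizes, not part of the commit; cGLU and RelativeQKBias are the classes added in the arch_util diff below.

import torch

x = torch.randn(2, 512, 100)          # channel-first: (batch, channels, sequence)

glu = cGLU(512)                       # Conv1d doubles the channels; half of them gate the other half
y = glu(x)                            # (2, 512, 100): proj(x).chunk(2)[0] * gelu(proj(x).chunk(2)[1])

bias = RelativeQKBias(l=64)           # one learned scalar per clamped relative distance 0..64
qk_bias = bias(x.shape[-1])           # (1, 100, 100), added to the QK logits before the softmax
# In SubBlock this becomes: self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1]))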

View File

@@ -341,6 +341,20 @@ class Downsample(nn.Module):
         return self.op(x)
 
 
+class cGLU(nn.Module):
+    """
+    Gated GELU for channel-first architectures.
+    """
+    def __init__(self, dim_in, dim_out=None):
+        super().__init__()
+        dim_out = dim_in if dim_out is None else dim_out
+        self.proj = nn.Conv1d(dim_in, dim_out * 2, 1)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=1)
+        return x * F.gelu(gate)
+
+
 class ResBlock(nn.Module):
     """
     A residual block that can optionally change the number of channels.
@@ -439,7 +453,7 @@ class ResBlock(nn.Module):
         return self.skip_connection(x) + h
 
 
-def build_local_attention_mask(n, l, fixed_region):
+def build_local_attention_mask(n, l, fixed_region=0):
     """
     Builds an attention mask that focuses attention on local region
     Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
@@ -465,6 +479,24 @@ def test_local_attention_mask():
     print(build_local_attention_mask(9,4,1))
 
 
+class RelativeQKBias(nn.Module):
+    """
+    Very simple relative position bias scheme which should be directly added to QK matrix. This bias simply applies to
+    the distance from the given element.
+    """
+    def __init__(self, l, max_positions=4000):
+        super().__init__()
+        self.emb = nn.Parameter(torch.randn(l+1) * .01)
+        o = torch.arange(0,max_positions)
+        c = o.unsqueeze(-1).repeat(1,max_positions)
+        r = o.unsqueeze(0).repeat(max_positions,1)
+        M = ((-(r-c).abs())+l).clamp(0,l)
+        self.register_buffer('M', M, persistent=False)
+
+    def forward(self, n):
+        return self.emb[self.M[:n, :n]].view(1,n,n)
+
+
 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
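To make the indexing above concrete (my reading of the code, not commit text): M[i, j] = clamp(l - |i - j|, 0, l), so a pair at distance 0 reads emb[l], distance d <= l reads emb[l - d], and anything farther than l collapses onto emb[0]. A small hypothetical check:

b = RelativeQKBias(l=4, max_positions=16)
m = b.M[:6, :6]
# m[0] == tensor([4, 3, 2, 1, 0, 0]); every query row sees the same clamped-distance pattern,
# so the learned bias depends only on relative position.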
@@ -507,16 +539,22 @@ class AttentionBlock(nn.Module):
         self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
         self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
 
-    def forward(self, x, mask=None):
+    def forward(self, x, mask=None, qk_bias=None):
         if self.do_checkpoint:
-            if mask is not None:
-                return checkpoint(self._forward, x, mask)
+            if mask is None:
+                if qk_bias is None:
+                    return checkpoint(self._forward, x)
+                else:
+                    assert False, 'unsupported: qk_bias but no mask'
             else:
-                return checkpoint(self._forward, x)
+                if qk_bias is None:
+                    return checkpoint(self._forward, x, mask)
+                else:
+                    return checkpoint(self._forward, x, mask, qk_bias)
         else:
             return self._forward(x, mask)
 
-    def _forward(self, x, mask=None):
+    def _forward(self, x, mask=None, qk_bias=0):
         b, c, *spatial = x.shape
         if mask is not None:
             if len(mask.shape) == 2:
@@ -529,7 +567,7 @@ class AttentionBlock(nn.Module):
         if self.do_activation:
             x = F.silu(x, inplace=True)
         qkv = self.qkv(x)
-        h = self.attention(qkv, mask)
+        h = self.attention(qkv, mask, qk_bias)
         h = self.proj_out(h)
         xp = self.x_proj(x)
         return (xp + h).reshape(b, xp.shape[1], *spatial)
@@ -544,7 +582,7 @@ class QKVAttentionLegacy(nn.Module):
         super().__init__()
         self.n_heads = n_heads
 
-    def forward(self, qkv, mask=None):
+    def forward(self, qkv, mask=None, qk_bias=0):
         """
         Apply QKV attention.
@@ -559,6 +597,7 @@ class QKVAttentionLegacy(nn.Module):
         weight = torch.einsum(
             "bct,bcs->bts", q * scale, k * scale
         )  # More stable with f16 than dividing afterwards
+        weight = weight + qk_bias
         if mask is not None:
             mask = mask.repeat(self.n_heads, 1, 1)
             weight[mask.logical_not()] = -torch.inf
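A shape note on this addition (my reading, not commit text): weight here is (batch*n_heads, tokens, tokens) while RelativeQKBias returns (1, tokens, tokens), so the same learned bias broadcasts across every head and batch element; the default qk_bias=0 leaves existing callers unchanged.

# Hypothetical broadcast check:
weight = torch.randn(2 * 4, 100, 100)     # (batch*heads, queries, keys)
qk_bias = RelativeQKBias(l=64)(100)       # (1, 100, 100)
assert (weight + qk_bias).shape == weight.shape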
@@ -577,7 +616,7 @@ class QKVAttention(nn.Module):
         super().__init__()
         self.n_heads = n_heads
 
-    def forward(self, qkv, mask=None):
+    def forward(self, qkv, mask=None, qk_bias=0):
         """
         Apply QKV attention.

View File

@@ -6,7 +6,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
+from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
+    RelativeQKBias
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
 from trainer.networks import register_model
@@ -22,73 +23,73 @@ def is_sequence(t):
 class SubBlock(nn.Module):
-    def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout):
+    def __init__(self, inp_dim, contraction_dim, heads, dropout):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
-        self.blk_emb_proj = nn.Conv1d(blk_dim, inp_dim, 1)
         self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
+        self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
+        self.pos_bias = RelativeQKBias(l=64)
+        self.attn_glu = cGLU(contraction_dim)
         self.attnorm = nn.GroupNorm(8, contraction_dim)
         self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
+        self.ff_glu = cGLU(contraction_dim)
         self.ffnorm = nn.GroupNorm(8, contraction_dim)
-        self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
-        self.mask_initialized = False
 
-    def forward(self, x, blk_emb):
-        if self.mask is not None and not self.mask_initialized:
-            self.mask = self.mask.to(x.device)
-            self.mask_initialized = True
-        blk_enc = self.blk_emb_proj(blk_emb)
-        ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
-        ah = ah[:,:,blk_enc.shape[-1]:] # Strip off the blk_emc used for attention and re-align with x.
-        ah = F.gelu(self.attnorm(ah))
+    def forward(self, x):
+        ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
+        ah = self.attn_glu(self.attnorm(ah))
         h = torch.cat([ah, x], dim=1)
         hf = self.dropout(checkpoint(self.ff, h))
-        hf = F.gelu(self.ffnorm(hf))
+        hf = self.ff_glu(self.ffnorm(hf))
         h = torch.cat([h, hf], dim=1)
         return h
 class ConcatAttentionBlock(TimestepBlock):
-    def __init__(self, trunk_dim, contraction_dim, heads, dropout):
+    def __init__(self, trunk_dim, contraction_dim, blk_dim, heads, dropout):
         super().__init__()
+        self.contraction_dim = contraction_dim
         self.prenorm = nn.GroupNorm(8, trunk_dim)
-        self.block1 = SubBlock(trunk_dim, contraction_dim, trunk_dim, heads, dropout)
-        self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, trunk_dim, heads, dropout)
+        self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout)
+        self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
         self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
         self.out.weight.data.zero_()
 
     def forward(self, x, blk_emb):
         h = self.prenorm(x)
-        h = self.block1(h, blk_emb)
-        h = self.block2(h, blk_emb)
-        h = self.out(h[:,x.shape[1]:])
+        h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
+        h = self.block1(h)
+        h = self.block2(h)
+        h = self.out(h[:,-self.contraction_dim*4:])
         return h + x
 class ConditioningEncoder(nn.Module):
     def __init__(self,
                  spec_dim,
-                 embedding_dim,
+                 hidden_dim,
+                 out_dim,
                  num_resolutions,
                  attn_blocks=6,
                  num_attn_heads=4,
                  do_checkpointing=False):
         super().__init__()
         attn = []
-        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=5, stride=2)
-        self.resolution_embedding = nn.Embedding(num_resolutions, embedding_dim)
+        self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
+        self.resolution_embedding = nn.Embedding(num_resolutions, hidden_dim)
         self.resolution_embedding.weight.data.mul(.1)  # Reduces the relative influence of this embedding from the start.
         for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
-            attn.append(ResBlock(embedding_dim, dims=1, checkpointing_enabled=do_checkpointing))
+            attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
+            attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
         self.attn = nn.Sequential(*attn)
-        self.dim = embedding_dim
+        self.out = nn.Linear(hidden_dim, out_dim, bias=False)
+        self.dim = hidden_dim
         self.do_checkpointing = do_checkpointing
 
     def forward(self, x, resolution):
         h = self.init(x) + self.resolution_embedding(resolution).unsqueeze(-1)
         h = self.attn(h)
-        return h[:, :, :5]
+        return self.out(h[:, :, 0])
 
 
 class TransformerDiffusion(nn.Module):
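The channel bookkeeping in the reworked SubBlock/ConcatAttentionBlock is easy to lose track of. The conditioning is now broadcast along the sequence and concatenated channel-wise instead of prepended as extra tokens, and each SubBlock appends its gated attention and feed-forward features (contraction_dim channels each) onto its input. The arithmetic below is my own illustration from the diff above, with made-up sizes:

trunk_dim, blk_dim, contraction_dim = 1024, 448, 512   # hypothetical sizes

in1 = trunk_dim + blk_dim                  # block1 sees trunk + broadcast blk_emb
out1 = in1 + 2 * contraction_dim           # block1 appends attention + ff features
in2 = out1                                 # matches SubBlock(trunk_dim+blk_dim+contraction_dim*2, ...)
out2 = in2 + 2 * contraction_dim
assert out2 - in1 == contraction_dim * 4   # why self.out, a 1x1 conv, consumes the trailing contraction_dim*4 channels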
@@ -97,7 +98,6 @@ class TransformerDiffusion(nn.Module):
     """
     def __init__(
             self,
-            time_embed_dim=256,
             resolution_steps=8,
             max_window=384,
             model_channels=1024,
@@ -106,6 +106,9 @@ class TransformerDiffusion(nn.Module):
             in_channels=256,
             input_vec_dim=1024,
             out_channels=512,  # mean and variance
+            time_embed_dim=256,
+            time_proj_dim=64,
+            cond_proj_dim=256,
             num_heads=4,
             dropout=0,
             use_fp16=False,
@@ -128,19 +131,20 @@ class TransformerDiffusion(nn.Module):
         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
             nn.SiLU(),
-            linear(time_embed_dim, model_channels),
+            linear(time_embed_dim, time_proj_dim),
         )
         self.prior_time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
             nn.SiLU(),
-            linear(time_embed_dim, model_channels),
+            linear(time_embed_dim, time_proj_dim),
         )
-        self.resolution_embed = nn.Embedding(resolution_steps, model_channels)
-        self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, resolution_steps, num_attn_heads=model_channels//64)
-        self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,5))
+        self.resolution_embed = nn.Embedding(resolution_steps, time_proj_dim)
+        self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
+        self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
 
         self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
-        self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
+        self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim,
+                                                                     num_heads, dropout) for _ in range(num_layers)])
 
         self.out = nn.Sequential(
             normalization(model_channels),
@@ -246,15 +250,14 @@ class TransformerDiffusion(nn.Module):
         # Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
-            unconditioned_batches = torch.rand((x.shape[0], 1, 1),
-                                               device=x.device) < self.unconditioned_percentage
-            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1), code_emb)
+            unconditioned_batches = torch.rand((x.shape[0], 1), device=x.device) < self.unconditioned_percentage
+            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1), code_emb)
 
         with torch.autocast(x.device.type, enabled=self.enable_fp16):
             time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
             prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim))
             res_emb = self.resolution_embed(resolution)
-            blk_emb = torch.cat([time_emb.unsqueeze(-1), prior_time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
+            blk_emb = torch.cat([time_emb, prior_time_emb, res_emb, code_emb], dim=1)
 
             h = torch.cat([x, x_prior], dim=1)
             h = self.inp_block(h)
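For orientation, here is how the new conditioning dimensions line up (my arithmetic from the diff, not commit text): time_emb, prior_time_emb and res_emb are each projected to time_proj_dim, code_emb comes out of ConditioningEncoder at cond_proj_dim, so blk_emb is a flat vector of time_proj_dim*3 + cond_proj_dim channels, exactly the blk_dim handed to each ConcatAttentionBlock, where it is broadcast along the sequence and concatenated onto the trunk.

# With the new defaults (time_proj_dim=64, cond_proj_dim=256):
time_proj_dim, cond_proj_dim = 64, 256
blk_dim = time_proj_dim * 3 + cond_proj_dim    # 448, the value passed to ConcatAttentionBlock
# blk_emb: (batch, 448); inside ConcatAttentionBlock it becomes (batch, 448, sequence_len)
# via blk_emb.unsqueeze(-1).repeat(1, 1, x.shape[-1]) before the channel-wise concat.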
@@ -304,5 +307,5 @@ def remove_conditioning(sd_path):
 
 
 if __name__ == '__main__':
-    remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
+    #remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
     test_tfd()

View File

@@ -146,7 +146,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         #     x = x.clamp(-s, s) / s
         #     return x
         sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
-        gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm})
+        gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm}, eta=.8)
         gen_mel_denorm = denormalize_torch_mel(gen_mel)
 
         output_shape = (1,16,audio.shape[-1]//16)
@@ -230,7 +230,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
         audio = audio.unsqueeze(0)
         mel = self.spec_fn({'in': audio})['out']
         mel_norm = normalize_torch_mel(mel)
-        #mel_norm = mel_norm[:,:,:448*4] # restricts first stage to optimal training window.
         conditioning = mel_norm[:,:,:1200]
         downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='nearest')
         stage1_shape = (1, 256, downsampled.shape[-1]*4)
@@ -323,19 +322,34 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
 if __name__ == '__main__':
+    """
+    # For multilevel SR:
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
                                        also_load_savepoint=False, strict_load=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\22000_generator.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\4000_generator.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
                 'diffusion_steps': 128,  # basis: 192
                 'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr',
+                'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr',
+                #'causal': True, 'causal_slope': 4,
+                #'partial_low': 128, 'partial_high': 192
                 }
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
+    """
+    # For TFD+cheater trainer
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
+                                       also_load_savepoint=False, strict_load=False,
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\20000_generator.pth'
+                                       ).cuda()
+    opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
+                #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
+                'diffusion_steps': 128,  # basis: 192
+                'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
+                'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
+                }
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 6, 'device': 'cuda', 'opt': {}}
 
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     fds = []
     for i in range(2):