hopefully this helps address the positional dependencies of tfd12
This commit is contained in:
James Betker 2022-07-19 13:30:05 -06:00
parent 4597447178
commit b157b28c7b
3 changed files with 314 additions and 8 deletions

View File

@ -492,6 +492,11 @@ class AttentionBlock(nn.Module):
def _forward(self, x, mask=None): def _forward(self, x, mask=None):
b, c, *spatial = x.shape b, c, *spatial = x.shape
if len(mask.shape) == 2:
mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
if mask.shape[1] != x.shape[-1]:
mask = mask[:, :x.shape[-1], :x.shape[-1]]
x = x.reshape(b, c, -1) x = x.reshape(b, c, -1)
x = self.norm(x) x = self.norm(x)
if self.do_activation: if self.do_activation:
@ -527,11 +532,10 @@ class QKVAttentionLegacy(nn.Module):
weight = torch.einsum( weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale "bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards ) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
if mask is not None: if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. mask = mask.repeat(self.n_heads, 1, 1)
mask = mask.repeat(self.n_heads, 1).unsqueeze(1) weight[mask.logical_not()] = -torch.inf
weight = weight * mask weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v) a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
@ -564,9 +568,8 @@ class QKVAttention(nn.Module):
(k * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards ) # More stable with f16 than dividing afterwards
if mask is not None: if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. mask = mask.repeat(self.n_heads, 1, 1)
mask = mask.repeat(self.n_heads, 1).unsqueeze(1) weight[mask.logical_not()] = -torch.inf
weight = weight * mask
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)

View File

@ -0,0 +1,303 @@
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.arch_util import AttentionBlock, TimestepEmbedSequential
from models.audio.music.encoders import ResEncoder16x
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
from trainer.networks import register_model
from utils.util import checkpoint, print_network
def build_local_attention_mask(n, l, fixed_region):
"""
Builds an attention mask that focuses attention on local region
Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
Args:
n: Size of returned matrix (maximum sequence size)
l: Size of local context (uni-directional, e.g. the total context is l*2)
fixed_region: The number of sequence elements at the start of the sequence that get full attention.
Returns:
A mask that can be applied to AttentionBlock to achieve local attention.
"""
assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
o = torch.arange(0,n)
c = o.unsqueeze(-1).repeat(1,n)
r = o.unsqueeze(0).repeat(n,1)
localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
localized[:fixed_region] = 1
localized[:, :fixed_region] = 1
mask = localized > 0
return mask
def test_local_attention_mask():
print(build_local_attention_mask(9,4,1))
class SubBlock(nn.Module):
def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout, enable_attention_masking=False):
super().__init__()
self.enable_attention_masking = enable_attention_masking
self.dropout = nn.Dropout(p=dropout)
self.blk_emb_proj = nn.Conv1d(blk_dim, inp_dim, 1)
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
self.attnorm = nn.GroupNorm(8, contraction_dim)
self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
self.ffnorm = nn.GroupNorm(8, contraction_dim)
if self.enable_attention_masking:
# All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region.
self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1)
else:
self.mask = None
def forward(self, x, blk_emb):
if self.mask is not None:
self.mask = self.mask.to(x.device)
blk_enc = self.blk_emb_proj(blk_emb)
ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
ah = F.gelu(self.attnorm(ah))
h = torch.cat([ah, x], dim=1)
hf = self.dropout(checkpoint(self.ff, h))
hf = F.gelu(self.ffnorm(hf))
h = torch.cat([h, hf], dim=1)
return h
class ConcatAttentionBlock(TimestepBlock):
def __init__(self, trunk_dim, contraction_dim, heads, dropout, enable_attention_masking=False):
super().__init__()
self.prenorm = nn.GroupNorm(8, trunk_dim)
self.block1 = SubBlock(trunk_dim, contraction_dim, trunk_dim, heads, dropout,
enable_attention_masking=enable_attention_masking)
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, trunk_dim, heads, dropout,
enable_attention_masking=enable_attention_masking)
self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
self.out.weight.data.zero_()
def forward(self, x, blk_emb):
h = self.prenorm(x)
h = self.block1(h, blk_emb)
h = self.block2(h, blk_emb)
h = self.out(h[:,x.shape[1]:])
return h + x
class TransformerDiffusion(nn.Module):
"""
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
"""
def __init__(
self,
time_embed_dim=256,
model_channels=1024,
contraction_dim=256,
num_layers=8,
in_channels=256,
input_vec_dim=1024,
out_channels=512, # mean and variance
num_heads=4,
dropout=0,
use_corner_alignment=False, # This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE.
use_fp16=False,
new_code_expansion=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# Parameters for re-training head
freeze_except_code_converters=False,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.time_embed_dim = time_embed_dim
self.out_channels = out_channels
self.dropout = dropout
self.unconditioned_percentage = unconditioned_percentage
self.enable_fp16 = use_fp16
self.new_code_expansion = new_code_expansion
self.use_corner_alignment = use_corner_alignment
self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
self.time_embed = nn.Sequential(
linear(time_embed_dim, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, model_channels),
)
self.input_converter = nn.Conv1d(input_vec_dim, model_channels, 1)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
self.intg = nn.Conv1d(model_channels*2, model_channels, 1)
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout, enable_attention_masking=True) for _ in range(num_layers)])
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
)
if freeze_except_code_converters:
for p in self.parameters():
p.DO_NOT_TRAIN = True
p.requires_grad = False
for m in [self.code_converter and self.input_converter]:
for p in m.parameters():
del p.DO_NOT_TRAIN
p.requires_grad = True
def get_grad_norm_parameter_groups(self):
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers]))
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers]))
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
groups = {
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
'blk1_attention_layers': attn1,
'blk2_attention_layers': attn2,
'attention_layers': attn1 + attn2,
'blk1_ff_layers': ff1,
'blk2_ff_layers': ff2,
'ff_layers': ff1 + ff2,
'block_out_layers': blkout_layers,
'out': list(self.out.parameters()),
'x_proj': list(self.inp_block.parameters()),
'layers': list(self.layers.parameters()),
'time_embed': list(self.time_embed.parameters()),
}
return groups
def forward(self, x, timesteps, prior=None, conditioning_free=False):
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
else:
code_emb = self.input_converter(prior)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
code_emb)
code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
with torch.autocast(x.device.type, enabled=self.enable_fp16):
blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)).unsqueeze(-1)
x = self.inp_block(x)
x = self.intg(torch.cat([x, code_emb], dim=1))
for layer in self.layers:
x = checkpoint(layer, x, blk_emb)
x = x.float()
out = self.out(x)
# Defensively involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
unused_params = [self.unconditioned_embedding]
extraneous_addition = 0
for p in unused_params:
extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0
return out
class TransformerDiffusionWithCheaterLatent(nn.Module):
def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs):
super().__init__()
self.internal_step = 0
self.freeze_encoder_until = freeze_encoder_until
self.diff = TransformerDiffusion(**kwargs)
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
def forward(self, x, timesteps, truth_mel, conditioning_free=False):
unused_parameters = []
encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until
if not encoder_grad_enabled:
unused_parameters.extend(list(self.encoder.parameters()))
with torch.set_grad_enabled(encoder_grad_enabled):
proj = self.encoder(truth_mel)
for p in unused_parameters:
proj = proj + p.mean() * 0
diff = self.diff(x, timesteps, prior=proj, conditioning_free=conditioning_free)
return diff
def get_debug_values(self, step, __):
self.internal_step = step
return {}
def get_grad_norm_parameter_groups(self):
groups = self.diff.get_grad_norm_parameter_groups()
groups['encoder'] = list(self.encoder.parameters())
return groups
def before_step(self, step):
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
# directly fiddling with the gradients.
for p in scaled_grad_parameters:
if hasattr(p, 'grad') and p.grad is not None:
p.grad *= .2
@register_model
def register_transformer_diffusion14(opt_net, opt):
return TransformerDiffusion(**opt_net['kwargs'])
@register_model
def register_transformer_diffusion_14_with_cheater_latent(opt_net, opt):
return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
def test_tfd():
clip = torch.randn(2,256,400)
ts = torch.LongTensor([600, 600])
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
num_heads=3, input_vec_dim=256, num_layers=12, dropout=.1)
model(clip, ts, clip)
def test_cheater_model():
clip = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
# For music:
model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
model_channels=1024, contraction_dim=512, num_heads=8,
input_vec_dim=256, num_layers=16,
dropout=.1, new_code_expansion=True,
)
#diff_weights = torch.load('extracted_diff.pth')
#model.diff.load_state_dict(diff_weights, strict=False)
#model.encoder.load_state_dict(torch.load('../experiments/music_cheater_encoder_256.pth', map_location=torch.device('cpu')), strict=True)
#torch.save(model.state_dict(), 'sample.pth')
print_network(model)
o = model(clip, ts, clip)
pg = model.get_grad_norm_parameter_groups()
def extract_cheater_encoder(in_f, out_f):
p = torch.load(in_f)
out = {}
for k, v in p.items():
if k.startswith('encoder.'):
out[k] = v
torch.save(out, out_f)
if __name__ == '__main__':
#test_local_attention_mask()
extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
test_cheater_model()
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)

View File

@ -80,7 +80,7 @@ class GaussianDiffusionInjector(Injector):
def forward(self, state): def forward(self, state):
gen = self.env['generators'][self.opt['generator']] gen = self.env['generators'][self.opt['generator']]
hq = state[self.input] hq = state[self.input]
assert hq.max() < 1 or hq.min() > -1, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}" assert hq.max() < 1.000001 or hq.min() > -1.00001, f"Attempting to train gaussian diffusion on un-normalized inputs. This won't work, silly! {hq.min()} {hq.max()}"
with autocast(enabled=self.env['opt']['fp16']): with autocast(enabled=self.env['opt']['fp16']):
if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0): if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):