r2 of the flat diffusion

James Betker 2022-03-21 11:40:43 -06:00
parent c5000420f6
commit 26dcf7f1a2
2 changed files with 355 additions and 114 deletions

View File

@@ -0,0 +1,306 @@
import functools
import math
import random
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint  # torch.utils.checkpoint.checkpoint is used by CheckpointedLayer below

from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
    DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, Rezero, \
    exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, not_equals, \
    LayerIntermediates, AttentionLayers
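
# TimeIntegrationBlock provides per-layer timestep conditioning: it projects the timestep embedding to a
# per-channel (scale, shift) pair and applies it on top of the wrapped norm, i.e.
# norm(x) * (1 + scale) + shift, broadcast over the sequence dimension. This is the same
# scale/shift-norm (FiLM-style) conditioning the diffusion ResBlock uses, applied token-wise.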
class TimeIntegrationBlock(nn.Module):
    def __init__(self, time_emb_dim, dim, normalizer):
        super().__init__()
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                time_emb_dim,
                2 * dim
            ),
        )
        self.normalizer = normalizer

    def forward(self, x, time_emb):
        emb_out = self.emb_layers(time_emb).type(x.dtype)
        scale, shift = torch.chunk(emb_out, 2, dim=1)
        x = self.normalizer(x)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class CheckpointedLayer(nn.Module):
    """
    Wraps a module. When forward() is called, the positional args (which may require grad) are routed
    through torch.utils.checkpoint(); kwargs are bound outside of the checkpoint and therefore must not
    require grad.
    """
    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        for k, v in kwargs.items():
            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
        partial = functools.partial(self.wrap, **kwargs)
        return torch.utils.checkpoint.checkpoint(partial, x, *args)
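
# Note for callers: CheckpointedLayer binds kwargs into a functools.partial *outside* of the checkpoint
# call, so kwargs must not require grad (the assert above enforces this). Anything that needs gradients
# has to be passed positionally so torch.utils.checkpoint sees it as an input.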


class TimestepEmbeddingAttentionLayers(AttentionLayers):
    """
    Modification of x-transformers.AttentionLayers that performs checkpointing, timestep embeddings and layerdrop.
    """
    def __init__(
        self,
        dim,
        timestep_dim,
        depth,
        heads = 8,
        causal = False,
        cross_attend = False,
        only_cross = False,
        use_scalenorm = False,
        use_rmsnorm = False,
        use_rezero = False,
        alibi_pos_bias = False,
        alibi_num_heads = None,
        alibi_learned = False,
        rel_pos_bias = False,
        rel_pos_num_buckets = 32,
        rel_pos_max_distance = 128,
        position_infused_attn = False,
        rotary_pos_emb = False,
        rotary_emb_dim = None,
        custom_layers = None,
        sandwich_coef = None,
        par_ratio = None,
        residual_attn = False,
        cross_residual_attn = False,
        macaron = False,
        gate_residual = False,
        scale_residual = False,
        shift_tokens = 0,
        use_qk_norm_attn = False,
        qk_norm_attn_seq_len = None,
        zero_init_branch_output = False,
        layerdrop_percent = .1,
        **kwargs
    ):
        super().__init__(dim, depth)
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])
        self.layerdrop_percent = layerdrop_percent

        self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None

        rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None

        assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

        if rel_pos_bias:
            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
        elif alibi_pos_bias:
            alibi_num_heads = default(alibi_num_heads, heads)
            assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
            alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
            self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, bidirectional = not causal)
        else:
            self.rel_pos = None

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn
        self.cross_attend = cross_attend

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        # qk normalization
        if use_qk_norm_attn:
            attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
            attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}

        # zero init
        if zero_init_branch_output:
            attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
            ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

        # calculate layer block order
        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        # calculate token shifting
        shift_tokens = cast_tuple(shift_tokens, len(layer_types))

        # iterate and construct layers
        for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
            if layer_type == 'a':
                layer = CheckpointedLayer(Attention(dim, heads = heads, causal = causal, **attn_kwargs))
            elif layer_type == 'c':
                layer = CheckpointedLayer(Attention(dim, heads = heads, **attn_kwargs))
            elif layer_type == 'f':
                layer = CheckpointedLayer(FeedForward(dim, **ff_kwargs))
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if layer_shift_tokens > 0:
                shift_range_upper = layer_shift_tokens + 1
                shift_range_lower = -layer_shift_tokens if not causal else 0
                layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

            if exists(branch_fn):
                layer = branch_fn(layer)

            residual_fn = GRUGating if gate_residual else Residual
            residual = residual_fn(dim, scale_residual = scale_residual)

            layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')

            pre_branch_norm = TimeIntegrationBlock(timestep_dim, dim, norm_fn())
            post_branch_norm = norm_fn() if layer_uses_qk_norm else None
            post_main_norm = None  # Always do prenorm for timestep integration.

            norms = nn.ModuleList([
                pre_branch_norm,
                post_branch_norm,
                post_main_norm
            ])

            self.layers.append(nn.ModuleList([
                norms,
                layer,
                residual
            ]))
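
    # Each entry of self.layers is (norms, block, residual):
    #   norms[0] is a TimeIntegrationBlock (pre-norm plus timestep scale/shift),
    #   norms[1] is an optional post-branch norm (only used with qk-norm attention),
    #   norms[2] (post-main norm) is always None, since timestep integration assumes prenorm.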

    def forward(
        self,
        x,
        time_emb = None,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        mems = None,
        return_hiddens = False
    ):
        assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
        assert time_emb is not None, 'must specify a timestep embedding.'

        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        rotary_pos_emb = None
        if exists(self.rotary_pos_emb):
            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
            rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

        unused_params = []
        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            # Do layer drop where applicable. Do not drop first and last layers.
            if self.training and self.layerdrop_percent > 0 and not is_last and ind != 0 and random.random() < self.layerdrop_percent:
                # Record the unused parameters so they can be used in null-operations later to not trigger DDP.
                unused_params.extend(list(block.parameters()))
                unused_params.extend(list(residual_fn.parameters()))
                unused_params.extend(list(norm.parameters()))
                continue

            if layer_type == 'a':
                hiddens.append(x)
                layer_mem = mems.pop(0) if mems else None

            residual = x

            pre_branch_norm, post_branch_norm, post_main_norm = norm

            x = pre_branch_norm(x, time_emb)

            if layer_type == 'a':
                out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
            elif layer_type == 'c':
                out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
            elif layer_type == 'f':
                out = block(x)

            if exists(post_branch_norm):
                out = post_branch_norm(out)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if exists(post_main_norm):
                x = post_main_norm(x)

        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
        extraneous_addition = 0
        for p in unused_params:
            extraneous_addition = extraneous_addition + p.mean()
        x = x + extraneous_addition * 0

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens = hiddens,
                attn_intermediates = intermediates
            )
            return x, intermediates

        return x
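

if __name__ == '__main__':
    # Minimal usage sketch. The dims/depth below are illustrative assumptions, not values from any
    # config; it only demonstrates the call signature: a (batch, seq, dim) sequence plus a
    # (batch, timestep_dim) embedding passed as time_emb.
    layers = TimestepEmbeddingAttentionLayers(dim=512, timestep_dim=512, depth=4, heads=8,
                                              use_rmsnorm=True, ff_glu=True, layerdrop_percent=0)
    seq = torch.randn(2, 100, 512)   # (batch, sequence, model dim)
    t_emb = torch.randn(2, 512)      # per-sample timestep embedding
    out = layers(seq, time_emb=t_emb)
    print(out.shape)                 # torch.Size([2, 100, 512])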

View File

@@ -6,6 +6,7 @@ import torch.nn.functional as F
from torch import autocast
from x_transformers import Encoder

from models.audio.tts.diffusion_encoder import TimestepEmbeddingAttentionLayers
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
    Downsample, Upsample, TimestepBlock
@@ -23,94 +24,6 @@ def is_sequence(t):
    return t.dtype == torch.long


class ResBlock(TimestepBlock):
    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        dims=2,
        kernel_size=3,
        efficient_config=True,
        use_scale_shift_norm=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_scale_shift_norm = use_scale_shift_norm
        padding = {1: 0, 3: 1, 5: 2}[kernel_size]
        eff_kernel = 1 if efficient_config else 3
        eff_padding = 0 if efficient_config else 1

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, x, emb
        )

    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class DiffusionLayer(nn.Module):
    def __init__(self, model_channels, dropout, num_heads):
        super().__init__()
        self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
        self.attn = AttentionBlock(model_channels, num_heads)

    def forward(self, x, time_emb):
        y = self.resblk(x, time_emb)
        return self.attn(y)
class DiffusionTtsFlat(nn.Module):
    def __init__(
        self,
@@ -120,7 +33,6 @@ class DiffusionTtsFlat(nn.Module):
        in_latent_channels=512,
        in_tokens=8193,
        max_timesteps=4000,
        max_positions=4000,
        out_channels=200,  # mean and variance
        dropout=0,
        use_fp16=False,
@@ -140,9 +52,13 @@ class DiffusionTtsFlat(nn.Module):
        self.enable_fp16 = use_fp16
        self.layer_drop = layer_drop
        self.inp_block = conv_nd(1, in_channels, model_channels//2, 3, 1, 1)
        self.position_embed = nn.Embedding(max_positions, model_channels//2)
        self.time_embed = nn.Embedding(max_timesteps, model_channels)
        self.inp_block = nn.Conv1d(in_channels, model_channels, kernel_size=3, padding=1)
        time_embed_dim = model_channels
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
        # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@@ -189,15 +105,42 @@ class DiffusionTtsFlat(nn.Module):
                                                 attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
        self.conditioning_conv = nn.Conv1d(model_channels*2, model_channels, 1)
        self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
        self.conditioning_timestep_integrator = TimestepEmbedSequential(
            ResBlock(model_channels, model_channels, dropout, out_channels=model_channels, dims=1, kernel_size=1, use_scale_shift_norm=True),
            AttentionBlock(model_channels, num_heads=num_heads),
            ResBlock(model_channels, model_channels, dropout, out_channels=model_channels, dims=1, kernel_size=1, use_scale_shift_norm=True),
            AttentionBlock(model_channels, num_heads=num_heads),
            ResBlock(model_channels, model_channels, dropout, out_channels=model_channels//2, dims=1, kernel_size=1, use_scale_shift_norm=True),
        )
        self.conditioning_timestep_integrator = CheckpointedXTransformerEncoder(
            needs_permute=True,
            max_seq_len=-1,
            use_pos_emb=False,
            attn_layers=TimestepEmbeddingAttentionLayers(
                dim=model_channels,
                timestep_dim=time_embed_dim,
                depth=3,
                heads=num_heads,
                ff_dropout=dropout,
                attn_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                rotary_emb_dim=True,
                layerdrop_percent=0,
            )
        )
        self.integrate_conditioning = nn.Conv1d(model_channels*2, model_channels, 1)
        self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)])
        self.layers = CheckpointedXTransformerEncoder(
            needs_permute=True,
            max_seq_len=-1,
            use_pos_emb=False,
            attn_layers=TimestepEmbeddingAttentionLayers(
                dim=model_channels,
                timestep_dim=time_embed_dim,
                depth=num_layers,
                heads=num_heads,
                ff_dropout=dropout,
                attn_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                rotary_emb_dim=True,
                layerdrop_percent=layer_drop,
            )
        )
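        # Layerdrop is now handled inside TimestepEmbeddingAttentionLayers via layerdrop_percent
        # (0 for the conditioning integrator, layer_drop for the main stack), so forward() no longer
        # carries its own drop/autocast loop over DiffusionLayer blocks.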
        self.out = nn.Sequential(
            normalization(model_channels),
@@ -253,20 +196,12 @@ class DiffusionTtsFlat(nn.Module):
                                       code_emb)

        # Everything after this comment is timestep dependent.
        time_emb = self.time_embed(timesteps)
        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
        pos_emb = self.position_embed(torch.arange(0, x.shape[-1], device=x.device)).unsqueeze(0).repeat(x.shape[0],1,1).permute(0,2,1)
        x = self.inp_block(x) + pos_emb
        x = torch.cat([x, F.interpolate(code_emb, size=x.shape[-1], mode='nearest')], dim=1)
        for i, lyr in enumerate(self.layers):
            # Do layer drop where applicable. Do not drop first and last layers.
            if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
                unused_params.extend(list(lyr.parameters()))
            else:
                # First and last blocks will have autocast disabled for improved precision.
                with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
                    x = lyr(x, time_emb)
        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb=time_emb)
        x = self.inp_block(x)
        x = self.integrate_conditioning(torch.cat([x, F.interpolate(code_emb, size=x.shape[-1], mode='nearest')], dim=1))
        with torch.autocast(x.device.type, enabled=self.enable_fp16):
            x = self.layers(x, time_emb=time_emb)
        x = x.float()
        out = self.out(x)
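        # Timestep pathway (shapes follow the usual diffusion convention and are assumed here):
        # timesteps is a (batch,) tensor, timestep_embedding(timesteps, model_channels) yields
        # (batch, model_channels) sinusoidal features, self.time_embed maps them to (batch, time_embed_dim),
        # and every TimeIntegrationBlock projects that vector to per-layer scale/shift values.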