From 26dcf7f1a2b3b8a28a34ba03914932e892992cb3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 21 Mar 2022 11:40:43 -0600
Subject: [PATCH] r2 of the flat diffusion

---
 codes/models/audio/tts/diffusion_encoder.py   | 306 ++++++++++++++++++
 .../audio/tts/unet_diffusion_tts_flat.py      | 163 +++-------
 2 files changed, 355 insertions(+), 114 deletions(-)
 create mode 100644 codes/models/audio/tts/diffusion_encoder.py

diff --git a/codes/models/audio/tts/diffusion_encoder.py b/codes/models/audio/tts/diffusion_encoder.py
new file mode 100644
index 00000000..b408e6fc
--- /dev/null
+++ b/codes/models/audio/tts/diffusion_encoder.py
@@ -0,0 +1,306 @@
+import functools
+import math
+import random
+from functools import partial
+
+import torch
+import torch.nn as nn
+from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
+    DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, Rezero, \
+    exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, not_equals, LayerIntermediates, \
+    AttentionLayers
+
+
+class TimeIntegrationBlock(nn.Module):
+    def __init__(self, time_emb_dim, dim, normalizer):
+        super().__init__()
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(
+                time_emb_dim,
+                2 * dim
+            ),
+        )
+        self.normalizer = normalizer
+
+    def forward(self, x, time_emb):
+        emb_out = self.emb_layers(time_emb).type(x.dtype)
+        scale, shift = torch.chunk(emb_out, 2, dim=1)
+        x = self.normalizer(x)
+        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class CheckpointedLayer(nn.Module):
+    """
+    Wraps a module. When forward() is called, the input and positional args are routed through torch.utils.checkpoint.checkpoint(),
+    while kwargs (which must not require grad) are bound outside of the checkpoint.
+    """
+    def __init__(self, wrap):
+        super().__init__()
+        self.wrap = wrap
+
+    def forward(self, x, *args, **kwargs):
+        for k, v in kwargs.items():
+            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
+        partial = functools.partial(self.wrap, **kwargs)
+        return torch.utils.checkpoint.checkpoint(partial, x, *args)
+
+
+class TimestepEmbeddingAttentionLayers(AttentionLayers):
+    """
+    Modification of x-transformers.AttentionLayers that adds gradient checkpointing, timestep-conditioned pre-normalization and layerdrop.
+ """ + def __init__( + self, + dim, + timestep_dim, + depth, + heads = 8, + causal = False, + cross_attend = False, + only_cross = False, + use_scalenorm = False, + use_rmsnorm = False, + use_rezero = False, + alibi_pos_bias = False, + alibi_num_heads = None, + alibi_learned = False, + rel_pos_bias = False, + rel_pos_num_buckets = 32, + rel_pos_max_distance = 128, + position_infused_attn = False, + rotary_pos_emb = False, + rotary_emb_dim = None, + custom_layers = None, + sandwich_coef = None, + par_ratio = None, + residual_attn = False, + cross_residual_attn = False, + macaron = False, + gate_residual = False, + scale_residual = False, + shift_tokens = 0, + use_qk_norm_attn = False, + qk_norm_attn_seq_len = None, + zero_init_branch_output = False, + layerdrop_percent = .1, + **kwargs + ): + super().__init__(dim, depth) + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + self.layerdrop_percent = layerdrop_percent + + self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + + rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) + self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None + + assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + + if rel_pos_bias: + self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance) + elif alibi_pos_bias: + alibi_num_heads = default(alibi_num_heads, heads) + assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' + alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias + self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, bidirectional = not causal) + else: + self.rel_pos = None + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + self.cross_attend = cross_attend + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + # qk normalization + + if use_qk_norm_attn: + attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None + attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value} + + # zero init + + if zero_init_branch_output: + attn_kwargs = {**attn_kwargs, 'zero_init_output': True} + ff_kwargs = {**ff_kwargs, 'zero_init_output': True} + + # calculate layer block order + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= 
par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + # calculate token shifting + + shift_tokens = cast_tuple(shift_tokens, len(layer_types)) + + # iterate and construct layers + for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): + if layer_type == 'a': + layer = CheckpointedLayer(Attention(dim, heads = heads, causal = causal, **attn_kwargs)) + elif layer_type == 'c': + layer = CheckpointedLayer(Attention(dim, heads = heads, **attn_kwargs)) + elif layer_type == 'f': + layer = CheckpointedLayer(FeedForward(dim, **ff_kwargs)) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if layer_shift_tokens > 0: + shift_range_upper = layer_shift_tokens + 1 + shift_range_lower = -layer_shift_tokens if not causal else 0 + layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) + + if exists(branch_fn): + layer = branch_fn(layer) + + residual_fn = GRUGating if gate_residual else Residual + residual = residual_fn(dim, scale_residual = scale_residual) + + layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c') + + pre_branch_norm = TimeIntegrationBlock(timestep_dim, dim, norm_fn()) + post_branch_norm = norm_fn() if layer_uses_qk_norm else None + post_main_norm = None # Always do prenorm for timestep integration. + + norms = nn.ModuleList([ + pre_branch_norm, + post_branch_norm, + post_main_norm + ]) + + self.layers.append(nn.ModuleList([ + norms, + layer, + residual + ])) + + def forward( + self, + x, + time_emb = None, + context = None, + mask = None, + context_mask = None, + attn_mask = None, + mems = None, + return_hiddens = False + ): + assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True' + assert time_emb is not None, 'must specify a timestep embedding.' + + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + rotary_pos_emb = None + if exists(self.rotary_pos_emb): + max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems))) + rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) + + unused_params = [] + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + # Do layer drop where applicable. Do not drop first and last layers. 
+ if self.training and self.layerdrop_percent > 0 and not is_last and ind != 0 and random.random() < self.layerdrop_percent: + # Record the unused parameters so they can be used in null-operations later to not trigger DDP. + unused_params.extend(list(block.parameters())) + unused_params.extend(list(residual_fn.parameters())) + unused_params.extend(list(norm.parameters())) + continue + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) if mems else None + + residual = x + + pre_branch_norm, post_branch_norm, post_main_norm = norm + + x = pre_branch_norm(x, time_emb) + + if layer_type == 'a': + out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem) + elif layer_type == 'c': + out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + if exists(post_branch_norm): + out = post_branch_norm(out) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if exists(post_main_norm): + x = post_main_norm(x) + + # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. + extraneous_addition = 0 + for p in unused_params: + extraneous_addition = extraneous_addition + p.mean() + x = x + extraneous_addition * 0 + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens = hiddens, + attn_intermediates = intermediates + ) + + return x, intermediates + + return x \ No newline at end of file diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py index c910144d..b7c91031 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from torch import autocast from x_transformers import Encoder +from models.audio.tts.diffusion_encoder import TimestepEmbeddingAttentionLayers from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ Downsample, Upsample, TimestepBlock @@ -23,94 +24,6 @@ def is_sequence(t): return t.dtype == torch.long -class ResBlock(TimestepBlock): - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - dims=2, - kernel_size=3, - efficient_config=True, - use_scale_shift_norm=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_scale_shift_norm = use_scale_shift_norm - padding = {1: 0, 3: 1, 5: 2}[kernel_size] - eff_kernel = 1 if efficient_config else 3 - eff_padding = 0 if efficient_config else 1 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding), - ) - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, 
self.out_channels, self.out_channels, kernel_size, padding=padding) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, x, emb - ) - - def _forward(self, x, emb): - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = torch.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class DiffusionLayer(nn.Module): - def __init__(self, model_channels, dropout, num_heads): - super().__init__() - self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True) - self.attn = AttentionBlock(model_channels, num_heads) - - def forward(self, x, time_emb): - y = self.resblk(x, time_emb) - return self.attn(y) - - class DiffusionTtsFlat(nn.Module): def __init__( self, @@ -120,7 +33,6 @@ class DiffusionTtsFlat(nn.Module): in_latent_channels=512, in_tokens=8193, max_timesteps=4000, - max_positions=4000, out_channels=200, # mean and variance dropout=0, use_fp16=False, @@ -140,9 +52,13 @@ class DiffusionTtsFlat(nn.Module): self.enable_fp16 = use_fp16 self.layer_drop = layer_drop - self.inp_block = conv_nd(1, in_channels, model_channels//2, 3, 1, 1) - self.position_embed = nn.Embedding(max_positions, model_channels//2) - self.time_embed = nn.Embedding(max_timesteps, model_channels) + self.inp_block = nn.Conv1d(in_channels, model_channels, kernel_size=3, padding=1) + time_embed_dim = model_channels + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. 
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@@ -189,15 +105,42 @@
                                                     attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
         self.conditioning_conv = nn.Conv1d(model_channels*2, model_channels, 1)
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
-        self.conditioning_timestep_integrator = TimestepEmbedSequential(
-            ResBlock(model_channels, model_channels, dropout, out_channels=model_channels, dims=1, kernel_size=1, use_scale_shift_norm=True),
-            AttentionBlock(model_channels, num_heads=num_heads),
-            ResBlock(model_channels, model_channels, dropout, out_channels=model_channels, dims=1, kernel_size=1, use_scale_shift_norm=True),
-            AttentionBlock(model_channels, num_heads=num_heads),
-            ResBlock(model_channels, model_channels, dropout, out_channels=model_channels//2, dims=1, kernel_size=1, use_scale_shift_norm=True),
-        )
+        self.conditioning_timestep_integrator = CheckpointedXTransformerEncoder(
+            needs_permute=True,
+            max_seq_len=-1,
+            use_pos_emb=False,
+            attn_layers=TimestepEmbeddingAttentionLayers(
+                dim=model_channels,
+                timestep_dim=time_embed_dim,
+                depth=3,
+                heads=num_heads,
+                ff_dropout=dropout,
+                attn_dropout=dropout,
+                use_rmsnorm=True,
+                ff_glu=True,
+                rotary_pos_emb=True,
+                layerdrop_percent=0,
+            )
+        )
+        self.integrate_conditioning = nn.Conv1d(model_channels*2, model_channels, 1)
 
-        self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)])
+        self.layers = CheckpointedXTransformerEncoder(
+            needs_permute=True,
+            max_seq_len=-1,
+            use_pos_emb=False,
+            attn_layers=TimestepEmbeddingAttentionLayers(
+                dim=model_channels,
+                timestep_dim=time_embed_dim,
+                depth=num_layers,
+                heads=num_heads,
+                ff_dropout=dropout,
+                attn_dropout=dropout,
+                use_rmsnorm=True,
+                ff_glu=True,
+                rotary_pos_emb=True,
+                layerdrop_percent=layer_drop,
+            )
+        )
 
         self.out = nn.Sequential(
             normalization(model_channels),
@@ -253,20 +196,12 @@
                                    code_emb)
 
         # Everything after this comment is timestep dependent.
-        time_emb = self.time_embed(timesteps)
-        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
-        pos_emb = self.position_embed(torch.arange(0, x.shape[-1], device=x.device)).unsqueeze(0).repeat(x.shape[0],1,1).permute(0,2,1)
-        x = self.inp_block(x) + pos_emb
-        x = torch.cat([x, F.interpolate(code_emb, size=x.shape[-1], mode='nearest')], dim=1)
-        for i, lyr in enumerate(self.layers):
-            # Do layer drop where applicable. Do not drop first and last layers.
-            if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
-                unused_params.extend(list(lyr.parameters()))
-            else:
-                # First and last blocks will have autocast disabled for improved precision.
-                with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
-                    x = lyr(x, time_emb)
-
+        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb=time_emb)
+        x = self.inp_block(x)
+        x = self.integrate_conditioning(torch.cat([x, F.interpolate(code_emb, size=x.shape[-1], mode='nearest')], dim=1))
+        with torch.autocast(x.device.type, enabled=self.enable_fp16):
+            x = self.layers(x, time_emb=time_emb)
         x = x.float()
         out = self.out(x)
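The core mechanism introduced by diffusion_encoder.py is a timestep-conditioned pre-norm: ahead of every attention and feed-forward block, TimeIntegrationBlock projects the timestep embedding into a per-channel scale and shift that modulate the normalized activations (FiLM/AdaLN-style conditioning). Below is a minimal, self-contained sketch of that mechanism in isolation; the class and variable names are illustrative stand-ins, not code from this patch.

import torch
import torch.nn as nn

class TimeScaleShift(nn.Module):
    # Illustrative stand-in mirroring TimeIntegrationBlock: normalize first, then apply a
    # scale/shift derived from the timestep embedding.
    def __init__(self, time_emb_dim, dim):
        super().__init__()
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, 2 * dim))
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, time_emb):
        # x: [batch, seq, dim]; time_emb: [batch, time_emb_dim]
        scale, shift = self.proj(time_emb).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

if __name__ == '__main__':
    block = TimeScaleShift(time_emb_dim=512, dim=512)
    x = torch.randn(2, 100, 512)
    t = torch.randn(2, 512)
    print(block(x, t).shape)  # torch.Size([2, 100, 512])

Because the timestep signal enters only through this pre-norm (post_main_norm is left as None in the layer construction above), every layer is conditioned on the timestep without an additional normalization after the residual add.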
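TimestepEmbeddingAttentionLayers.forward() also implements layerdrop in a DDP-friendly way: layers that are randomly skipped during training still contribute a zero-weighted term built from their parameters, so DistributedDataParallel finds every parameter in the autograd graph. A minimal sketch of that pattern, again with illustrative names rather than repo code:

import random
import torch
import torch.nn as nn

class LayerDroppedStack(nn.Module):
    # Illustrative stand-in; the real class also routes norms, attention kwargs and residuals.
    def __init__(self, dim, depth, layerdrop_percent=0.1):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
        self.layerdrop_percent = layerdrop_percent

    def forward(self, x):
        unused_params = []
        last = len(self.layers) - 1
        for ind, layer in enumerate(self.layers):
            # Never drop the first or last layer; only drop while training.
            if self.training and 0 < ind < last and random.random() < self.layerdrop_percent:
                unused_params.extend(layer.parameters())
                continue
            x = x + layer(x)
        # Touch the skipped parameters through a zero-weighted term so DDP does not flag them as unused.
        extraneous = sum((p.mean() for p in unused_params), torch.zeros((), device=x.device))
        return x + extraneous * 0

if __name__ == '__main__':
    stack = LayerDroppedStack(dim=64, depth=6)
    stack.train()
    print(stack(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])

The alternative would be constructing DistributedDataParallel with find_unused_parameters=True; the zero-weighted sum keeps that flag off and avoids its extra graph traversal.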