From 37bdfe82b2f1951752c0e605ef0048859b4771dc Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 6 Apr 2022 00:35:29 -0600
Subject: [PATCH] Modify x_transformers to do checkpointing and use relative
 positional biases

---
 .../audio/tts/autoregressive_codegen.py   | 53 +++----------
 codes/models/lucidrains/x_transformers.py | 36 +++++++------
 2 files changed, 28 insertions(+), 61 deletions(-)

diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py
index 6ed288b0..b820879b 100644
--- a/codes/models/audio/tts/autoregressive_codegen.py
+++ b/codes/models/audio/tts/autoregressive_codegen.py
@@ -1,17 +1,12 @@
-import functools
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2PreTrainedModel, GPT2Config
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-from x_transformers import TransformerWrapper, Encoder, Decoder
 
-from data.audio.voice_tokenizer import VoiceBpeTokenizer
 from models.arch_util import AttentionBlock
-from scripts.audio.gen.speech_synthesis_utils import wav_to_mel
+from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
 from trainer.networks import register_model
-from utils.util import load_audio
 
 
 class InferenceModel(GPT2PreTrainedModel):
@@ -92,7 +87,7 @@ class InferenceModel(GPT2PreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
-        logits = self.transformer.decoder.transformer.to_logits(hidden_states)
+        logits = self.transformer.decoder.to_logits(hidden_states)
 
         if not return_dict:
             return (logits, )
@@ -161,40 +156,6 @@ class ConditioningEncoder(nn.Module):
         return h.mean(dim=2)
 
 
-class CheckpointedLayer(nn.Module):
-    """
-    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
-    checkpoint for all other args.
-    """
-    def __init__(self, wrap):
-        super().__init__()
-        self.wrap = wrap
-
-    def forward(self, x, *args, **kwargs):
-        for k, v in kwargs.items():
-            assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
-        partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
-
-
-class CheckpointedXTransformerWrapper(nn.Module):
-    """
-    Wraps a TransformerWrapper and applies CheckpointedLayer to each layer.
- """ - def __init__(self, checkpoint=True, **xtransformer_kwargs): - super().__init__() - self.transformer = TransformerWrapper(**xtransformer_kwargs) - - if not checkpoint: - return - for i in range(len(self.transformer.attn_layers.layers)): - n, b, r = self.transformer.attn_layers.layers[i] - self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) - - def forward(self, x, **kwargs): - return self.transformer(x, **kwargs) - - class AutoregressiveCodegen(nn.Module): def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1): super().__init__() @@ -204,7 +165,7 @@ class AutoregressiveCodegen(nn.Module): self.max_text_token_id = num_text_tokens self.max_mel_token_id = num_mel_tokens self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False) - self.encoder = CheckpointedXTransformerWrapper( + self.encoder = TransformerWrapper( num_tokens=num_text_tokens, use_pos_emb=False, max_seq_len=-1, @@ -218,9 +179,10 @@ class AutoregressiveCodegen(nn.Module): ff_glu=True, ff_mult=1, rotary_pos_emb=True, + attn_rel_pos_bias=True, )) - self.encoder.transformer.to_logits = nn.Identity() # This is unused. - self.decoder = CheckpointedXTransformerWrapper( + self.encoder.to_logits = nn.Identity() # This is unused. + self.decoder = TransformerWrapper( num_tokens=num_mel_tokens, use_pos_emb=False, max_seq_len=-1, @@ -235,6 +197,7 @@ class AutoregressiveCodegen(nn.Module): ff_mult=1, rotary_pos_emb=True, cross_attend=True, + attn_rel_pos_bias=True, )) def get_grad_norm_parameter_groups(self): @@ -289,7 +252,7 @@ class AutoregressiveCodegen(nn.Module): gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN, max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, **hf_generate_kwargs) - return gen + return gen.sequences @register_model diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py index f3dd0551..c7dfc3c3 100644 --- a/codes/models/lucidrains/x_transformers.py +++ b/codes/models/lucidrains/x_transformers.py @@ -1,3 +1,4 @@ +import functools import math import torch from torch import nn, einsum @@ -14,6 +15,7 @@ from entmax import entmax15 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper # constants +from utils.util import checkpoint DEFAULT_DIM_HEAD = 64 @@ -445,7 +447,10 @@ class Attention(nn.Module): zero_init_output = False, max_attend_past = None, qk_norm = False, - scale_init_value = None + scale_init_value = None, + rel_pos_bias = False, + rel_pos_num_buckets = 32, + rel_pos_max_distance = 128, ): super().__init__() self.scale = dim_head ** -0.5 @@ -508,6 +513,11 @@ class Attention(nn.Module): self.attn_on_attn = on_attn self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim) + self.rel_pos_bias = rel_pos_bias + if rel_pos_bias: + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance) + # init output projection 0 if zero_init_output: init_zero_(self.to_out) @@ -519,7 +529,6 @@ class Attention(nn.Module): mask = None, context_mask = None, attn_mask = None, - rel_pos = None, sinusoidal_emb = None, rotary_pos_emb = None, prev_attn = None, @@ -593,8 +602,8 @@ class 
         if talking_heads:
             dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
 
-        if exists(rel_pos):
-            dots = rel_pos(dots)
+        if self.rel_pos_bias:
+            dots = self.rel_pos(dots)
 
         if exists(input_mask):
             dots.masked_fill_(~input_mask, mask_value)
@@ -673,9 +682,6 @@ class AttentionLayers(nn.Module):
         alibi_pos_bias = False,
         alibi_num_heads = None,
         alibi_learned = False,
-        rel_pos_bias = False,
-        rel_pos_num_buckets = 32,
-        rel_pos_max_distance = 128,
         position_infused_attn = False,
         rotary_pos_emb = False,
         rotary_emb_dim = None,
@@ -705,6 +711,7 @@ class AttentionLayers(nn.Module):
         self.depth = depth
         self.layers = nn.ModuleList([])
 
+        rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
         self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
 
@@ -712,11 +719,8 @@ class AttentionLayers(nn.Module):
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
 
         assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
-        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
-        if rel_pos_bias:
-            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
-        elif alibi_pos_bias:
+        if alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
             alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
@@ -858,8 +862,6 @@ class AttentionLayers(nn.Module):
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
-            is_last = ind == (len(self.layers) - 1)
-
             if layer_type == 'a':
                 hiddens.append(x)
                 layer_mem = mems.pop(0) if mems else None
@@ -872,11 +874,13 @@ class AttentionLayers(nn.Module):
                 x = pre_branch_norm(x)
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
+                block_fn = functools.partial(block, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
+                out, inter = checkpoint(block_fn, x)
             elif layer_type == 'c':
-                out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
+                block_fn = functools.partial(block, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
+                out, inter = checkpoint(block_fn, x, context)
             elif layer_type == 'f':
-                out = block(x)
+                out = checkpoint(block, x)
 
             if exists(post_branch_norm):
                 out = post_branch_norm(out)
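
Note on the relative-position change (explanatory, not part of the patch): the T5-style relative position bias now lives inside each Attention module rather than in AttentionLayers. x_transformers strips the "attn_" prefix from constructor kwargs and forwards the remainder to every attention layer, so the attn_rel_pos_bias=True that autoregressive_codegen.py now passes to Encoder/Decoder arrives as rel_pos_bias=True in each Attention, which builds its own RelativePositionBias and adds it to the attention logits. A minimal construction sketch follows; the sizes are illustrative placeholders, not values taken from the patch:

    from models.lucidrains.x_transformers import TransformerWrapper, Encoder

    encoder = TransformerWrapper(
        num_tokens=256,        # illustrative vocab size
        use_pos_emb=False,     # absolute positional embeddings stay disabled
        max_seq_len=-1,
        attn_layers=Encoder(
            dim=512,           # illustrative model width
            depth=8,
            heads=8,
            rotary_pos_emb=True,
            attn_rel_pos_bias=True,  # "attn_" prefix is stripped; reaches each Attention as rel_pos_bias=True
        ))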
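
Note on the checkpointing change (explanatory, not part of the patch): the layer loop now wraps every block in the checkpoint() helper imported from utils.util. Kwargs that never need gradients (masks, rotary embeddings, cached memories) are frozen into the callable with functools.partial, and only grad-carrying tensors are passed positionally, matching the constraint the removed CheckpointedLayer asserted; this is also why the cross-attention context tensor goes through checkpoint() positionally rather than being bound into the partial. The sketch below is an assumed stand-in for utils.util.checkpoint, for illustration only; the repo's real helper may differ (for example by gating on a global config flag):

    import torch
    import torch.utils.checkpoint

    def checkpoint(fn, *args):
        # Recompute activations during backward only when a grad-requiring tensor
        # is among the positional inputs; otherwise just call the function.
        if any(isinstance(a, torch.Tensor) and a.requires_grad for a in args):
            return torch.utils.checkpoint.checkpoint(fn, *args)
        return fn(*args)

    # Usage pattern mirroring the patched AttentionLayers.forward:
    #   block_fn = functools.partial(block, mask=mask, attn_mask=attn_mask, prev_attn=prev_attn)
    #   out, inter = checkpoint(block_fn, x)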