Modify x_transformers to do checkpointing and use relative positional biases

James Betker 2022-04-06 00:35:29 -06:00
parent 09879b434d
commit 37bdfe82b2
2 changed files with 28 additions and 61 deletions
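For context, a minimal usage sketch (not part of the commit) of how the reworked text encoder is built after this change, assuming the vendored models.lucidrains.x_transformers keeps the upstream TransformerWrapper/Encoder API; the dimensions below are illustrative:

# Hypothetical sketch; dim/depth stand in for whatever model_dim/depth the caller passes.
from models.lucidrains.x_transformers import TransformerWrapper, Encoder

encoder = TransformerWrapper(
    num_tokens=256,              # num_text_tokens
    max_seq_len=-1,              # unused: absolute positional embeddings are disabled below
    use_pos_emb=False,
    attn_layers=Encoder(
        dim=512,                 # model_dim (illustrative)
        depth=8,                 # depth (illustrative)
        rotary_pos_emb=True,
        attn_rel_pos_bias=True,  # new flag: T5-style relative position bias in each attention layer
    ))

Gradient checkpointing no longer relies on the CheckpointedXTransformerWrapper removed in the first file; the vendored AttentionLayers now checkpoints each block internally (second file).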


@@ -1,17 +1,12 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2PreTrainedModel, GPT2Config
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from x_transformers import TransformerWrapper, Encoder, Decoder
from data.audio.voice_tokenizer import VoiceBpeTokenizer
from models.arch_util import AttentionBlock
from scripts.audio.gen.speech_synthesis_utils import wav_to_mel
from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
from trainer.networks import register_model
from utils.util import load_audio
class InferenceModel(GPT2PreTrainedModel):
@@ -92,7 +87,7 @@ class InferenceModel(GPT2PreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
logits = self.transformer.decoder.transformer.to_logits(hidden_states)
logits = self.transformer.decoder.to_logits(hidden_states)
if not return_dict:
return (logits, )
@@ -161,40 +156,6 @@ class ConditioningEncoder(nn.Module):
return h.mean(dim=2)
class CheckpointedLayer(nn.Module):
"""
Wraps a module. When forward() is called, positional args (which may require grad) are passed through
torch.utils.checkpoint.checkpoint(), while kwargs (which must not require grad) are bound outside of checkpointing.
"""
def __init__(self, wrap):
super().__init__()
self.wrap = wrap
def forward(self, x, *args, **kwargs):
for k, v in kwargs.items():
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
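# Note: torch.utils.checkpoint.checkpoint() only forwards positional arguments and only tracks
# gradients for tensors passed that way, so kwargs are bound via functools.partial below; a kwarg
# tensor that required grad would silently lose its gradient, which is what this assert guards against.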
partial = functools.partial(self.wrap, **kwargs)
return torch.utils.checkpoint.checkpoint(partial, x, *args)
class CheckpointedXTransformerWrapper(nn.Module):
"""
Wraps a TransformerWrapper and applies CheckpointedLayer to each layer.
"""
def __init__(self, checkpoint=True, **xtransformer_kwargs):
super().__init__()
self.transformer = TransformerWrapper(**xtransformer_kwargs)
if not checkpoint:
return
for i in range(len(self.transformer.attn_layers.layers)):
n, b, r = self.transformer.attn_layers.layers[i]
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
def forward(self, x, **kwargs):
return self.transformer(x, **kwargs)
class AutoregressiveCodegen(nn.Module):
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
super().__init__()
@@ -204,7 +165,7 @@ class AutoregressiveCodegen(nn.Module):
self.max_text_token_id = num_text_tokens
self.max_mel_token_id = num_mel_tokens
self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
self.encoder = CheckpointedXTransformerWrapper(
self.encoder = TransformerWrapper(
num_tokens=num_text_tokens,
use_pos_emb=False,
max_seq_len=-1,
@@ -218,9 +179,10 @@ class AutoregressiveCodegen(nn.Module):
ff_glu=True,
ff_mult=1,
rotary_pos_emb=True,
attn_rel_pos_bias=True,
))
self.encoder.transformer.to_logits = nn.Identity() # This is unused.
self.decoder = CheckpointedXTransformerWrapper(
self.encoder.to_logits = nn.Identity() # This is unused.
self.decoder = TransformerWrapper(
num_tokens=num_mel_tokens,
use_pos_emb=False,
max_seq_len=-1,
@@ -235,6 +197,7 @@ class AutoregressiveCodegen(nn.Module):
ff_mult=1,
rotary_pos_emb=True,
cross_attend=True,
attn_rel_pos_bias=True,
))
def get_grad_norm_parameter_groups(self):
@@ -289,7 +252,7 @@ class AutoregressiveCodegen(nn.Module):
gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
**hf_generate_kwargs)
return gen
return gen.sequences
@register_model
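On the return gen.sequences change above: with return_dict_in_generate=True, Hugging Face's generate() returns an output object rather than a bare tensor, so the caller now gets just the generated token ids. A minimal illustration (names are illustrative):

out = model.generate(input_ids, return_dict_in_generate=True, max_length=max_tokens)
tokens = out.sequences   # LongTensor of shape (batch, sequence_length)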


@@ -1,3 +1,4 @@
import functools
import math
import torch
from torch import nn, einsum
@@ -14,6 +15,7 @@ from entmax import entmax15
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
# constants
from utils.util import checkpoint
DEFAULT_DIM_HEAD = 64
@@ -445,7 +447,10 @@ class Attention(nn.Module):
zero_init_output = False,
max_attend_past = None,
qk_norm = False,
scale_init_value = None
scale_init_value = None,
rel_pos_bias = False,
rel_pos_num_buckets = 32,
rel_pos_max_distance = 128,
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -508,6 +513,11 @@ class Attention(nn.Module):
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
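# RelativePositionBias is the T5-style bias: one learned scalar per head for each bucketed relative
# distance, added to the pre-softmax attention logits in forward() (see self.rel_pos(dots) below).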
# init output projection 0
if zero_init_output:
init_zero_(self.to_out)
@@ -519,7 +529,6 @@ class Attention(nn.Module):
mask = None,
context_mask = None,
attn_mask = None,
rel_pos = None,
sinusoidal_emb = None,
rotary_pos_emb = None,
prev_attn = None,
@@ -593,8 +602,8 @@ class Attention(nn.Module):
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
if self.rel_pos_bias:
dots = self.rel_pos(dots)
if exists(input_mask):
dots.masked_fill_(~input_mask, mask_value)
@@ -673,9 +682,6 @@ class AttentionLayers(nn.Module):
alibi_pos_bias = False,
alibi_num_heads = None,
alibi_learned = False,
rel_pos_bias = False,
rel_pos_num_buckets = 32,
rel_pos_max_distance = 128,
position_infused_attn = False,
rotary_pos_emb = False,
rotary_emb_dim = None,
@@ -705,6 +711,7 @@ class AttentionLayers(nn.Module):
self.depth = depth
self.layers = nn.ModuleList([])
rel_pos_bias = 'rel_pos_bias' in attn_kwargs
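# attn_*-prefixed kwargs are routed to the Attention blocks themselves, so the per-layer relative
# position bias is only detected here to keep has_pos_emb (and the ALiBi exclusivity check) accurate.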
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
@@ -712,11 +719,8 @@ class AttentionLayers(nn.Module):
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
if rel_pos_bias:
self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
elif alibi_pos_bias:
if alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads)
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
@@ -858,8 +862,6 @@ class AttentionLayers(nn.Module):
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
is_last = ind == (len(self.layers) - 1)
if layer_type == 'a':
hiddens.append(x)
layer_mem = mems.pop(0) if mems else None
@@ -872,11 +874,13 @@ class AttentionLayers(nn.Module):
x = pre_branch_norm(x)
if layer_type == 'a':
out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
block_fn = functools.partial(block, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
out, inter = checkpoint(block_fn, x)
elif layer_type == 'c':
out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
block_fn = functools.partial(block, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
out, inter = checkpoint(block_fn, x)
elif layer_type == 'f':
out = block(x)
out = checkpoint(block, x)
if exists(post_branch_norm):
out = post_branch_norm(out)
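The checkpoint() helper used above comes from utils.util and its implementation is not part of this diff. A minimal sketch of what such a helper could look like, assuming it simply defers to torch.utils.checkpoint when recomputation can pay off (the real helper may differ):

import torch
import torch.utils.checkpoint as cp

def checkpoint(fn, *args):
    # Hypothetical sketch: only pay the recompute cost when gradients are actually needed.
    wants_grad = torch.is_grad_enabled() and any(
        isinstance(a, torch.Tensor) and a.requires_grad for a in args)
    if wants_grad:
        return cp.checkpoint(fn, *args)
    return fn(*args)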