Modify x_transformers to do checkpointing and use relative positional biases
parent 09879b434d
commit 37bdfe82b2
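As a quick orientation for the hunks below, here is an illustrative sketch (not part of the commit) of the checkpointing pattern the modified AttentionLayers.forward() uses: keyword arguments are frozen into a callable with functools.partial, and the block is then run through activation checkpointing so its intermediate activations are recomputed during the backward pass rather than stored. ToyBlock is a made-up stand-in for an attention or feed-forward block, and torch.utils.checkpoint.checkpoint is called directly in place of the repo's utils.util.checkpoint wrapper.

import functools

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # Hypothetical stand-in for one transformer sub-block (attention or feed-forward).
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x, scale=1.0):
        return self.net(x) * scale


block = ToyBlock(64)
x = torch.randn(2, 16, 64, requires_grad=True)

# Freeze the non-tensor kwargs into the callable, then checkpoint over the tensor
# input, mirroring the diff's checkpoint(block_fn, x) calls. Activations inside
# the block are recomputed during backward instead of being kept in memory.
block_fn = functools.partial(block, scale=0.5)
out = checkpoint(block_fn, x)
out.sum().backward()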
@@ -1,17 +1,12 @@
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2PreTrainedModel, GPT2Config
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from x_transformers import TransformerWrapper, Encoder, Decoder

from data.audio.voice_tokenizer import VoiceBpeTokenizer
from models.arch_util import AttentionBlock
from scripts.audio.gen.speech_synthesis_utils import wav_to_mel
from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
from trainer.networks import register_model
from utils.util import load_audio


class InferenceModel(GPT2PreTrainedModel):
@@ -92,7 +87,7 @@ class InferenceModel(GPT2PreTrainedModel):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
        logits = self.transformer.decoder.transformer.to_logits(hidden_states)
        logits = self.transformer.decoder.to_logits(hidden_states)

        if not return_dict:
            return (logits, )
@@ -161,40 +156,6 @@ class ConditioningEncoder(nn.Module):
        return h.mean(dim=2)


class CheckpointedLayer(nn.Module):
    """
    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
    checkpoint for all other args.
    """
    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        for k, v in kwargs.items():
            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
        partial = functools.partial(self.wrap, **kwargs)
        return torch.utils.checkpoint.checkpoint(partial, x, *args)


class CheckpointedXTransformerWrapper(nn.Module):
    """
    Wraps a TransformerWrapper and applies CheckpointedLayer to each layer.
    """
    def __init__(self, checkpoint=True, **xtransformer_kwargs):
        super().__init__()
        self.transformer = TransformerWrapper(**xtransformer_kwargs)

        if not checkpoint:
            return
        for i in range(len(self.transformer.attn_layers.layers)):
            n, b, r = self.transformer.attn_layers.layers[i]
            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])

    def forward(self, x, **kwargs):
        return self.transformer(x, **kwargs)


class AutoregressiveCodegen(nn.Module):
    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
        super().__init__()
@@ -204,7 +165,7 @@ class AutoregressiveCodegen(nn.Module):
        self.max_text_token_id = num_text_tokens
        self.max_mel_token_id = num_mel_tokens
        self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
        self.encoder = CheckpointedXTransformerWrapper(
        self.encoder = TransformerWrapper(
            num_tokens=num_text_tokens,
            use_pos_emb=False,
            max_seq_len=-1,
@@ -218,9 +179,10 @@ class AutoregressiveCodegen(nn.Module):
                ff_glu=True,
                ff_mult=1,
                rotary_pos_emb=True,
                attn_rel_pos_bias=True,
            ))
        self.encoder.transformer.to_logits = nn.Identity()  # This is unused.
        self.decoder = CheckpointedXTransformerWrapper(
        self.encoder.to_logits = nn.Identity()  # This is unused.
        self.decoder = TransformerWrapper(
            num_tokens=num_mel_tokens,
            use_pos_emb=False,
            max_seq_len=-1,
@@ -235,6 +197,7 @@ class AutoregressiveCodegen(nn.Module):
                ff_mult=1,
                rotary_pos_emb=True,
                cross_attend=True,
                attn_rel_pos_bias=True,
            ))

    def get_grad_norm_parameter_groups(self):
@@ -289,7 +252,7 @@ class AutoregressiveCodegen(nn.Module):
        gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
                                            max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
                                            **hf_generate_kwargs)
        return gen
        return gen.sequences


@register_model
@@ -1,3 +1,4 @@
import functools
import math
import torch
from torch import nn, einsum
@@ -14,6 +15,7 @@ from entmax import entmax15
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# constants
from utils.util import checkpoint

DEFAULT_DIM_HEAD = 64

@@ -445,7 +447,10 @@ class Attention(nn.Module):
        zero_init_output = False,
        max_attend_past = None,
        qk_norm = False,
        scale_init_value = None
        scale_init_value = None,
        rel_pos_bias = False,
        rel_pos_num_buckets = 32,
        rel_pos_max_distance = 128,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
@@ -508,6 +513,11 @@ class Attention(nn.Module):
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)

        self.rel_pos_bias = rel_pos_bias
        if rel_pos_bias:
            assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)

        # init output projection 0
        if zero_init_output:
            init_zero_(self.to_out)
@@ -519,7 +529,6 @@
        mask = None,
        context_mask = None,
        attn_mask = None,
        rel_pos = None,
        sinusoidal_emb = None,
        rotary_pos_emb = None,
        prev_attn = None,
@@ -593,8 +602,8 @@
        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)
        if self.rel_pos_bias:
            dots = self.rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
@@ -673,9 +682,6 @@ class AttentionLayers(nn.Module):
        alibi_pos_bias = False,
        alibi_num_heads = None,
        alibi_learned = False,
        rel_pos_bias = False,
        rel_pos_num_buckets = 32,
        rel_pos_max_distance = 128,
        position_infused_attn = False,
        rotary_pos_emb = False,
        rotary_emb_dim = None,
@@ -705,6 +711,7 @@
        self.depth = depth
        self.layers = nn.ModuleList([])

        rel_pos_bias = 'rel_pos_bias' in attn_kwargs
        self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None

@@ -712,11 +719,8 @@
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None

        assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

        if rel_pos_bias:
            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
        elif alibi_pos_bias:
        if alibi_pos_bias:
            alibi_num_heads = default(alibi_num_heads, heads)
            assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
            alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
@@ -858,8 +862,6 @@
            rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == 'a':
                hiddens.append(x)
                layer_mem = mems.pop(0) if mems else None
@@ -872,11 +874,13 @@
                x = pre_branch_norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
                block_fn = functools.partial(block, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
                out, inter = checkpoint(block_fn, x)
            elif layer_type == 'c':
                out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
                block_fn = functools.partial(block, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
                out, inter = checkpoint(block_fn, x)
            elif layer_type == 'f':
                out = block(x)
                out = checkpoint(block, x)

            if exists(post_branch_norm):
                out = post_branch_norm(out)
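For reference, the relative-position hunks above move the T5-style bias from AttentionLayers into Attention: attn_rel_pos_bias=True reaches Attention as rel_pos_bias, which builds a RelativePositionBias and adds it to the attention logits via dots = self.rel_pos(dots). The sketch below (not part of the commit) shows the core idea with a deliberately simplified bias module: SimpleRelativePositionBias is a hypothetical name, and it uses plain clipped offsets rather than the log-spaced bucketing of x_transformers' RelativePositionBias.

import torch
import torch.nn as nn


class SimpleRelativePositionBias(nn.Module):
    # Hypothetical, simplified take on a T5-style relative position bias:
    # one learned scalar per head for each clipped (key - query) offset.
    def __init__(self, heads, max_distance=128):
        super().__init__()
        self.max_distance = max_distance
        self.bias = nn.Embedding(2 * max_distance + 1, heads)

    def forward(self, dots):
        # dots: raw attention logits of shape (batch, heads, q_len, k_len)
        q_len, k_len = dots.shape[-2:]
        q_pos = torch.arange(q_len, device=dots.device)[:, None]
        k_pos = torch.arange(k_len, device=dots.device)[None, :]
        rel = (k_pos - q_pos).clamp(-self.max_distance, self.max_distance) + self.max_distance
        bias = self.bias(rel)                      # (q_len, k_len, heads)
        return dots + bias.permute(2, 0, 1)[None]  # broadcast over the batch dim


rel_pos = SimpleRelativePositionBias(heads=8)
dots = torch.randn(2, 8, 16, 16)  # stand-in for q.k^T logits
dots = rel_pos(dots)              # same shape, biased before softmax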