Modify x_transformers to do checkpointing and use relative positional biases

parent 09879b434d
commit 37bdfe82b2
@@ -1,17 +1,12 @@
-import functools
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2PreTrainedModel, GPT2Config
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-from x_transformers import TransformerWrapper, Encoder, Decoder
 
-from data.audio.voice_tokenizer import VoiceBpeTokenizer
 from models.arch_util import AttentionBlock
-from scripts.audio.gen.speech_synthesis_utils import wav_to_mel
+from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
 from trainer.networks import register_model
-from utils.util import load_audio
 
 
 class InferenceModel(GPT2PreTrainedModel):
@@ -92,7 +87,7 @@ class InferenceModel(GPT2PreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
-        logits = self.transformer.decoder.transformer.to_logits(hidden_states)
+        logits = self.transformer.decoder.to_logits(hidden_states)
 
         if not return_dict:
             return (logits, )
@@ -161,40 +156,6 @@ class ConditioningEncoder(nn.Module):
         return h.mean(dim=2)
 
 
-class CheckpointedLayer(nn.Module):
-    """
-    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
-    checkpoint for all other args.
-    """
-    def __init__(self, wrap):
-        super().__init__()
-        self.wrap = wrap
-
-    def forward(self, x, *args, **kwargs):
-        for k, v in kwargs.items():
-            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
-        partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
-
-
-class CheckpointedXTransformerWrapper(nn.Module):
-    """
-    Wraps a TransformerWrapper and applies CheckpointedLayer to each layer.
-    """
-    def __init__(self, checkpoint=True, **xtransformer_kwargs):
-        super().__init__()
-        self.transformer = TransformerWrapper(**xtransformer_kwargs)
-
-        if not checkpoint:
-            return
-        for i in range(len(self.transformer.attn_layers.layers)):
-            n, b, r = self.transformer.attn_layers.layers[i]
-            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
-
-    def forward(self, x, **kwargs):
-        return self.transformer(x, **kwargs)
-
-
 class AutoregressiveCodegen(nn.Module):
     def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
         super().__init__()
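For context, the two classes deleted above implemented activation checkpointing from outside the transformer via torch.utils.checkpoint.checkpoint, which frees intermediate activations in the forward pass and recomputes the wrapped block during backward; keyword tensors were asserted not to require grad because they were baked into the functools.partial and so bypassed the checkpoint call's input handling (the removed comment calls this out). A minimal, self-contained illustration of the memory/compute trade-off — module and tensor names here are illustrative, not from this repo:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
x = torch.randn(4, 128, 512, requires_grad=True)

# Plain call: the activations of both Linear layers are stored for backward.
y = block(x)

# Checkpointed call: only the input is stored; the block is re-run during backward.
y_ckpt = checkpoint(block, x)
y_ckpt.sum().backward()  # yields the same gradients as the plain path
```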
@@ -204,7 +165,7 @@ class AutoregressiveCodegen(nn.Module):
         self.max_text_token_id = num_text_tokens
         self.max_mel_token_id = num_mel_tokens
         self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
-        self.encoder = CheckpointedXTransformerWrapper(
+        self.encoder = TransformerWrapper(
             num_tokens=num_text_tokens,
             use_pos_emb=False,
             max_seq_len=-1,
@@ -218,9 +179,10 @@ class AutoregressiveCodegen(nn.Module):
                 ff_glu=True,
                 ff_mult=1,
                 rotary_pos_emb=True,
+                attn_rel_pos_bias=True,
             ))
-        self.encoder.transformer.to_logits = nn.Identity()  # This is unused.
-        self.decoder = CheckpointedXTransformerWrapper(
+        self.encoder.to_logits = nn.Identity()  # This is unused.
+        self.decoder = TransformerWrapper(
             num_tokens=num_mel_tokens,
             use_pos_emb=False,
             max_seq_len=-1,
@@ -235,6 +197,7 @@ class AutoregressiveCodegen(nn.Module):
                 ff_mult=1,
                 rotary_pos_emb=True,
                 cross_attend=True,
+                attn_rel_pos_bias=True,
             ))
 
     def get_grad_norm_parameter_groups(self):
@@ -289,7 +252,7 @@ class AutoregressiveCodegen(nn.Module):
         gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
                                             max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
                                             **hf_generate_kwargs)
-        return gen
+        return gen.sequences
 
 
 @register_model
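The only behavioural change in the hunk above is `return gen.sequences`: with `return_dict_in_generate=True`, Hugging Face's `generate()` returns a ModelOutput-style object rather than a plain tensor, and the generated token ids live in its `.sequences` field. A minimal sketch of that behaviour (the `hf_model` and `input_ids` names are placeholders, not from this repo):

```python
# Any Hugging Face causal LM behaves this way when return_dict_in_generate=True.
gen = hf_model.generate(input_ids, max_length=64, return_dict_in_generate=True)
token_ids = gen.sequences  # LongTensor of shape (batch, generated_len)
```

The hunks that follow restart their line numbering (`@@ -1,3 +1,4 @@`), so they apply to a second file — presumably the vendored copy at models/lucidrains/x_transformers.py that the model file above now imports from.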
@@ -1,3 +1,4 @@
+import functools
 import math
 import torch
 from torch import nn, einsum
@@ -14,6 +15,7 @@ from entmax import entmax15
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
 # constants
+from utils.util import checkpoint
 
 DEFAULT_DIM_HEAD = 64
 
@@ -445,7 +447,10 @@ class Attention(nn.Module):
         zero_init_output = False,
         max_attend_past = None,
         qk_norm = False,
-        scale_init_value = None
+        scale_init_value = None,
+        rel_pos_bias = False,
+        rel_pos_num_buckets = 32,
+        rel_pos_max_distance = 128,
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -508,6 +513,11 @@ class Attention(nn.Module):
         self.attn_on_attn = on_attn
         self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
 
+        self.rel_pos_bias = rel_pos_bias
+        if rel_pos_bias:
+            assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
+            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+
         # init output projection 0
         if zero_init_output:
             init_zero_(self.to_out)
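`RelativePositionBias` here is the T5-style scheme: a learned, per-head bias looked up from the bucketed distance between query and key positions and added to the attention logits. A simplified sketch of the idea, using linear clamping instead of x_transformers' logarithmic bucketing (names are illustrative):

```python
import torch
import torch.nn as nn


class SimpleRelativeBias(nn.Module):
    """Per-head additive attention bias indexed by clamped relative distance."""
    def __init__(self, heads, max_distance=128):
        super().__init__()
        self.max_distance = max_distance
        # One learned scalar per head for every offset in [-max_distance, max_distance].
        self.bias = nn.Embedding(2 * max_distance + 1, heads)

    def forward(self, qk_dots):
        # qk_dots: (batch, heads, query_len, key_len)
        i, j = qk_dots.shape[-2:]
        q_pos = torch.arange(i, device=qk_dots.device)
        k_pos = torch.arange(j, device=qk_dots.device)
        rel = k_pos[None, :] - q_pos[:, None]                    # (i, j) signed offsets
        rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
        bias = self.bias(rel)                                    # (i, j, heads)
        return qk_dots + bias.permute(2, 0, 1).unsqueeze(0)      # broadcast over batch
```

In the vendored Attention.forward this is exactly where the change lands: `dots = self.rel_pos(dots)` transforms the raw logits before masking and softmax (see the `@@ -593,8 +602,8 @@` hunk below).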
@@ -519,7 +529,6 @@
         mask = None,
         context_mask = None,
         attn_mask = None,
-        rel_pos = None,
         sinusoidal_emb = None,
         rotary_pos_emb = None,
         prev_attn = None,
@@ -593,8 +602,8 @@
         if talking_heads:
             dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
 
-        if exists(rel_pos):
-            dots = rel_pos(dots)
+        if self.rel_pos_bias:
+            dots = self.rel_pos(dots)
 
         if exists(input_mask):
             dots.masked_fill_(~input_mask, mask_value)
@@ -673,9 +682,6 @@ class AttentionLayers(nn.Module):
         alibi_pos_bias = False,
         alibi_num_heads = None,
         alibi_learned = False,
-        rel_pos_bias = False,
-        rel_pos_num_buckets = 32,
-        rel_pos_max_distance = 128,
         position_infused_attn = False,
         rotary_pos_emb = False,
         rotary_emb_dim = None,
@@ -705,6 +711,7 @@
         self.depth = depth
         self.layers = nn.ModuleList([])
 
+        rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
         self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
 
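The `'rel_pos_bias' in attn_kwargs` test works because AttentionLayers splits its incoming kwargs by prefix and hands the `attn_`-prefixed ones (prefix stripped) to every Attention block; that is how `attn_rel_pos_bias=True` in the model file's Encoder/Decoder configs becomes `rel_pos_bias=True` on each attention layer. A rough sketch of that routing — the helper name and details below are assumptions for illustration, not the vendored code:

```python
def split_by_prefix(prefix, kwargs):
    """Collect kwargs starting with `prefix` (prefix stripped) and return the rest untouched."""
    taken = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return taken, rest


attn_kwargs, rest = split_by_prefix('attn_', {'attn_rel_pos_bias': True, 'ff_glu': True})
# attn_kwargs == {'rel_pos_bias': True}  -> forwarded to Attention(rel_pos_bias=True, ...)
# rest        == {'ff_glu': True}        -> later split again with the 'ff_' prefix
```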
@@ -712,11 +719,8 @@
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
 
         assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
-        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
-        if rel_pos_bias:
-            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
-        elif alibi_pos_bias:
+        if alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
             alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
@@ -858,8 +862,6 @@
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
-            is_last = ind == (len(self.layers) - 1)
-
             if layer_type == 'a':
                 hiddens.append(x)
                 layer_mem = mems.pop(0) if mems else None
@@ -872,11 +874,13 @@
                 x = pre_branch_norm(x)
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
+                block_fn = functools.partial(block, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
+                out, inter = checkpoint(block_fn, x)
             elif layer_type == 'c':
-                out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
+                block_fn = functools.partial(block, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
+                out, inter = checkpoint(block_fn, x)
             elif layer_type == 'f':
-                out = block(x)
+                out = checkpoint(block, x)
 
             if exists(post_branch_norm):
                 out = post_branch_norm(out)