unified_voice with rotary embeddings

This commit is contained in:
James Betker 2022-04-07 20:11:14 -06:00
parent 573e5552b9
commit 3f8d7955ef
3 changed files with 92 additions and 17 deletions

View File

@ -59,8 +59,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
"""
from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len+max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len,
n_positions=1,
n_ctx=1,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
@ -72,8 +72,10 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del gpt.wte
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
None, None
mel_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
return gpt, mel_pos_emb, text_pos_emb, None, None
def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):

View File

@ -3,10 +3,12 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from models.arch_util import AttentionBlock
from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
from trainer.networks import register_model
from utils.util import opt_get
@ -183,6 +185,73 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
)
class GPT2AttentionWithRotaryEmbeddings(GPT2Attention):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
self.rotary_pos_emb = RotaryEmbedding(32)
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
)
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
# Apply rotary embeddings. This is the only difference between this implementation and the HF one.
rotary_pos_emb = self.rotary_pos_emb(hidden_states.shape[1], hidden_states.device)
l = rotary_pos_emb.shape[-1]
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (query, key, value))
ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
query, key, value = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
class ConditioningEncoder(nn.Module):
def __init__(self,
spec_dim,
@ -239,7 +308,7 @@ class UnifiedVoice(nn.Module):
mel_length_compression=1024, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, average_conditioning_embeddings=False):
checkpointing=True, average_conditioning_embeddings=False, use_rotary_embeddings=False):
"""
Args:
layers: Number of layers in transformer stack.
@ -270,8 +339,8 @@ class UnifiedVoice(nn.Module):
self.stop_mel_token = stop_mel_token
self.layers = layers
self.heads = heads
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
@ -283,7 +352,7 @@ class UnifiedVoice(nn.Module):
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@ -291,6 +360,11 @@ class UnifiedVoice(nn.Module):
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
if use_rotary_embeddings:
# We must re-build all the attention layers as type GPT2AttentionWithRotaryEmbeddings.
for blk in self.gpt.h:
blk.attn = GPT2AttentionWithRotaryEmbeddings(self.gpt.config, layer_idx=blk.attn.layer_idx)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
@ -371,9 +445,6 @@ class UnifiedVoice(nn.Module):
If return_attentions is specified, only logits are returned.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
"""
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
@ -422,8 +493,6 @@ class UnifiedVoice(nn.Module):
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
"""
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
@ -477,7 +546,10 @@ class UnifiedVoice(nn.Module):
return loss_mel.mean()
def inference_speech(self, speech_conditioning_input, text_inputs, return_attentions=False, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also
seq_length = 2002 # Arbitrary default.
else:
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
if not hasattr(self, 'inference_model'):
# TODO: Decouple gpt_config from this inference model.
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
@ -566,10 +638,11 @@ def register_unified_voice2(opt_net, opt):
if __name__ == '__main__':
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4,
use_rotary_embeddings=True, max_mel_tokens=-1, max_text_tokens=-1)
l = gpt(torch.randn(2, 3, 80, 800),
torch.randint(high=256, size=(2,120)),
torch.tensor([32, 120]),
torch.randint(high=8192, size=(2,250)),
torch.tensor([250*256,195*256]))
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
#gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))

View File

@ -327,7 +327,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clvp.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)