unified_voice with rotary embeddings

James Betker 2022-04-07 20:11:14 -06:00
parent 573e5552b9
commit 3f8d7955ef
3 changed files with 92 additions and 17 deletions

models/audio/tts/transformer_builders.py

@@ -59,8 +59,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
     """
     from transformers import GPT2Config, GPT2Model
     gpt_config = GPT2Config(vocab_size=256,  # Unused.
-                            n_positions=max_mel_seq_len+max_text_seq_len,
-                            n_ctx=max_mel_seq_len+max_text_seq_len,
+                            n_positions=1,
+                            n_ctx=1,
                             n_embd=model_dim,
                             n_layer=layers,
                             n_head=heads,
@@ -72,8 +72,10 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
-    return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
-           None, None
+    mel_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
+    text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, model_dim) if max_text_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
+    return gpt, mel_pos_emb, text_pos_emb, None, None

 def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
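For reference, the `null_position_embeddings` helper that stands in for the learned tables is expected to be a producer of zero vectors, so a model built with `max_mel_seq_len == -1` receives no additive positional signal at all and relies entirely on the rotary embeddings introduced below. A minimal sketch of that assumed implementation:

    import torch

    def null_position_embeddings(range, dim):
        # Zero "embeddings": adding these to the token embeddings changes nothing,
        # leaving position information to be injected elsewhere (here: rotary).
        return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)

Because `functools.partial` pre-binds `dim`, the partial is a drop-in callable wherever the GPT-2 stack would invoke `wpe` or one of the `LearnedPositionEmbeddings` modules.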

models/audio/tts/unified_voice2.py

@@ -3,10 +3,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
 from models.arch_util import AttentionBlock
 from models.audio.tts.transformer_builders import build_hf_gpt_transformer
+from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -183,6 +185,73 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         )


+class GPT2AttentionWithRotaryEmbeddings(GPT2Attention):
+    def __init__(self, config, is_cross_attention=False, layer_idx=None):
+        super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
+        self.rotary_pos_emb = RotaryEmbedding(32)
+
+    def forward(
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        if encoder_hidden_states is not None:
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+                )
+            query = self.q_attn(hidden_states)
+            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            attention_mask = encoder_attention_mask
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
+
+        if use_cache is True:
+            present = (key, value)
+        else:
+            present = None
+
+        # Apply rotary embeddings. This is the only difference between this implementation and the HF one.
+        # Following the x_transformers convention this borrows from, only the first `l` channels of each
+        # head are rotated, and the rotation is applied to the values as well as the queries and keys.
+        rotary_pos_emb = self.rotary_pos_emb(hidden_states.shape[1], hidden_states.device)
+        l = rotary_pos_emb.shape[-1]
+        (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (query, key, value))
+        ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
+        query, key, value = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
+
+        if self.reorder_and_upcast_attn:
+            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+        else:
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)
 class ConditioningEncoder(nn.Module):
     def __init__(self,
                  spec_dim,
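For intuition about the rotary math above: `RotaryEmbedding(32)` yields a 32-channel frequency tensor, so `l == 32` and only the first 32 channels of each attention head are rotated while the remainder pass through untouched. A minimal self-contained sketch of that partial rotation, assuming the helpers behave like lucidrains' reference implementation (all sizes below are hypothetical):

    import torch

    def rotate_half(x):
        # Treat the rotary channels as two halves and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(t, freqs):
        # Classic RoPE: rotate each channel pair by a position-dependent angle.
        return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())

    b, h, n, d = 2, 4, 10, 64                      # batch, heads, seq len, head_dim
    q = torch.randn(b, h, n, d)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, 32, 2).float() / 32))
    freqs = torch.einsum('i,j->ij', torch.arange(n).float(), inv_freq)  # (n, 16)
    freqs = torch.cat((freqs, freqs), dim=-1)                           # (n, 32)

    l = freqs.shape[-1]                            # 32: the slice that gets rotated
    ql, qr = q[..., :l], q[..., l:]
    q = torch.cat((apply_rotary(ql, freqs), qr), dim=-1)  # same shape as the input

Because position is encoded as a rotation of query/key channels rather than an additive table, attention scores depend only on relative offsets, which is what lets the sequence-length caps elsewhere in this commit become optional.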
@@ -239,7 +308,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, use_rotary_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -270,8 +339,8 @@ class UnifiedVoice(nn.Module):
         self.stop_mel_token = stop_mel_token
         self.layers = layers
         self.heads = heads
-        self.max_mel_tokens = max_mel_tokens
-        self.max_text_tokens = max_text_tokens
+        self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+max_conditioning_inputs
+        self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
         self.model_dim = model_dim
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
@@ -283,7 +352,7 @@ class UnifiedVoice(nn.Module):
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
-            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
+            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@@ -291,6 +360,11 @@ class UnifiedVoice(nn.Module):
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0

+        if use_rotary_embeddings:
+            # We must re-build all the attention layers as type GPT2AttentionWithRotaryEmbeddings.
+            for blk in self.gpt.h:
+                blk.attn = GPT2AttentionWithRotaryEmbeddings(self.gpt.config, layer_idx=blk.attn.layer_idx)
+
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
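Swapping the modules after construction, rather than threading a flag through the builder, means the rotary variant works with any stock GPT-2 block. A minimal sketch of the same pattern on a plain Hugging Face model (hypothetical sizes; `layer_idx` is forwarded so the KV-cache bookkeeping inside `transformers` keeps working):

    from transformers import GPT2Config, GPT2Model

    config = GPT2Config(n_embd=256, n_layer=4, n_head=4)
    model = GPT2Model(config)
    for block in model.h:
        # Replace each attention module in place with the rotary subclass defined above.
        block.attn = GPT2AttentionWithRotaryEmbeddings(config, layer_idx=block.attn.layer_idx)

Note that the replacement modules are freshly initialized; that is fine here because UnifiedVoice is trained from scratch, but the same swap would discard the attention weights of a pretrained checkpoint.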
@@ -371,9 +445,6 @@ class UnifiedVoice(nn.Module):
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         """
-        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
-        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
-
         # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
         # chopping the inputs by the maximum actual length.
         max_text_len = text_lengths.max()
@@ -422,8 +493,6 @@ class UnifiedVoice(nn.Module):
         Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
         model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
         """
-        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
-
         # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
         # chopping the inputs by the maximum actual length.
         max_text_len = text_lengths.max()
@@ -477,7 +546,10 @@ class UnifiedVoice(nn.Module):
         return loss_mel.mean()

     def inference_speech(self, speech_conditioning_input, text_inputs, return_attentions=False, **hf_generate_kwargs):
-        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
+        if self.max_mel_tokens == -1:  # If this is the case, assume max_text_tokens == -1 as well.
+            seq_length = 2002  # Arbitrary default.
+        else:
+            seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
             gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
@@ -566,10 +638,11 @@ def register_unified_voice2(opt_net, opt):

 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4,
+                       use_rotary_embeddings=True, max_mel_tokens=-1, max_text_tokens=-1)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
             torch.tensor([250*256,195*256]))
-    gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
+    #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))

train.py

@@ -327,7 +327,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clvp.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)