diff --git a/codes/models/audio/tts/transformer_builders.py b/codes/models/audio/tts/transformer_builders.py
index e215ac21..8ce96f38 100644
--- a/codes/models/audio/tts/transformer_builders.py
+++ b/codes/models/audio/tts/transformer_builders.py
@@ -59,8 +59,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
     """
     from transformers import GPT2Config, GPT2Model
     gpt_config = GPT2Config(vocab_size=256,  # Unused.
-                            n_positions=max_mel_seq_len+max_text_seq_len,
-                            n_ctx=max_mel_seq_len+max_text_seq_len,
+                            n_positions=1,
+                            n_ctx=1,
                             n_embd=model_dim,
                             n_layer=layers,
                             n_head=heads,
@@ -72,8 +72,10 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
-    return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
-           None, None
+
+    mel_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
+    text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, model_dim) if max_text_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
+    return gpt, mel_pos_emb, text_pos_emb, None, None


 def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py
index 0cd78aa2..97f6bf34 100644
--- a/codes/models/audio/tts/unified_voice2.py
+++ b/codes/models/audio/tts/unified_voice2.py
@@ -3,10 +3,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 from transformers.utils.model_parallel_utils import get_device_map, assert_device_map

 from models.arch_util import AttentionBlock
 from models.audio.tts.transformer_builders import build_hf_gpt_transformer
+from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
 from trainer.networks import register_model
 from utils.util import opt_get

@@ -183,6 +185,73 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         )


+
+class GPT2AttentionWithRotaryEmbeddings(GPT2Attention):
+    def __init__(self, config, is_cross_attention=False, layer_idx=None):
+        super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
+        self.rotary_pos_emb = RotaryEmbedding(32)
+
+    def forward(
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        if encoder_hidden_states is not None:
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+                )
+
+            query = self.q_attn(hidden_states)
+            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            attention_mask = encoder_attention_mask
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
+
+        if use_cache is True:
+            present = (key, value)
+        else:
+            present = None
+
+        # Apply rotary embeddings. This is the only difference between this implementation and the HF one.
+        rotary_pos_emb = self.rotary_pos_emb(hidden_states.shape[1], hidden_states.device)
+        l = rotary_pos_emb.shape[-1]
+        (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (query, key, value))
+        ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
+        query, key, value = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
+
+        if self.reorder_and_upcast_attn:
+            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+        else:
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)
+
+
 class ConditioningEncoder(nn.Module):
     def __init__(self,
                  spec_dim,
@@ -239,7 +308,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, use_rotary_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -270,8 +339,8 @@ class UnifiedVoice(nn.Module):
         self.stop_mel_token = stop_mel_token
         self.layers = layers
         self.heads = heads
-        self.max_mel_tokens = max_mel_tokens
-        self.max_text_tokens = max_text_tokens
+        self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+max_conditioning_inputs
+        self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
         self.model_dim = model_dim
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
@@ -283,7 +352,7 @@
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
-            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
+            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@@ -291,6 +360,11 @@
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0

+        if use_rotary_embeddings:
+            # We must re-build all the attention layers as type GPT2AttentionWithRotaryEmbeddings.
+            for blk in self.gpt.h:
+                blk.attn = GPT2AttentionWithRotaryEmbeddings(self.gpt.config, layer_idx=blk.attn.layer_idx)
+
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
@@ -371,9 +445,6 @@ class UnifiedVoice(nn.Module):
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         """
-        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
-        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
-
         # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
         # chopping the inputs by the maximum actual length.
         max_text_len = text_lengths.max()
@@ -422,8 +493,6 @@
         Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
         model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
         """
-        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
-
         # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
         # chopping the inputs by the maximum actual length.
         max_text_len = text_lengths.max()
@@ -477,7 +546,10 @@
         return loss_mel.mean()

     def inference_speech(self, speech_conditioning_input, text_inputs, return_attentions=False, **hf_generate_kwargs):
-        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
+        if self.max_mel_tokens == -1:  # Assume that, in this case, max_text_tokens == -1 as well.
+            seq_length = 2002  # Arbitrary default.
+        else:
+            seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
             gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
@@ -566,10 +638,11 @@ def register_unified_voice2(opt_net, opt):


 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4,
+                       use_rotary_embeddings=True, max_mel_tokens=-1, max_text_tokens=-1)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
             torch.tensor([250*256,195*256]))
-    gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
+    #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
diff --git a/codes/train.py b/codes/train.py
index 70578721..ad4ce62a 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -327,7 +327,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clvp.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
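
Note on the core change: GPT2AttentionWithRotaryEmbeddings rotates only the first 32 channels of each query/key/value head (RotaryEmbedding(32)) and passes the remaining channels through untouched, using the vendored models.lucidrains.x_transformers helpers. The snippet below is a minimal, self-contained sketch of that partial-rotary scheme in plain PyTorch; the helper names (make_rotary_freqs, rotate_half, apply_rotary) are illustrative and are not part of this repository.

```python
import torch


def make_rotary_freqs(seq_len, rot_dim, device, base=10000.0):
    # Angle table of shape (seq_len, rot_dim); duplicated so cos/sin line up with rotate_half.
    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, device=device).float() / rot_dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)   # (seq_len, rot_dim // 2)
    return torch.cat((freqs, freqs), dim=-1)       # (seq_len, rot_dim)


def rotate_half(x):
    # Swap the two halves of the last dim, negating the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(t, freqs):
    # t: (batch, heads, seq, rot_dim); freqs broadcasts over batch and heads.
    return t * freqs.cos() + rotate_half(t) * freqs.sin()


if __name__ == '__main__':
    batch, heads, seq, head_dim, rot_dim = 2, 4, 10, 64, 32
    q = torch.randn(batch, heads, seq, head_dim)
    freqs = make_rotary_freqs(seq, rot_dim, q.device)
    # Rotate only the first rot_dim channels of each head; the rest pass through unchanged.
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    q = torch.cat((apply_rotary(q_rot, freqs), q_pass), dim=-1)
    print(q.shape)  # torch.Size([2, 4, 10, 64])
```

The rest of the patch wires this in for unbounded sequence lengths: passing max_mel_tokens=-1 / max_text_tokens=-1 makes build_hf_gpt_transformer substitute null (all-zero) position embeddings for the learned ones, and inference_speech falls back to an arbitrary seq_length of 2002, so positional information then comes entirely from the rotary attention layers.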