From 09ab1aa9bcb49717c90c6597628abb52cdc170d6 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 8 Apr 2022 16:18:35 -0600
Subject: [PATCH] revert rotary embeddings work

I'm not really sure that this is going to work. I'd rather explore re-using
what I've already trained
---
 codes/models/audio/tts/unified_voice2.py | 77 +-----------------------
 1 file changed, 2 insertions(+), 75 deletions(-)

diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py
index 97f6bf34..f0d0f6b7 100644
--- a/codes/models/audio/tts/unified_voice2.py
+++ b/codes/models/audio/tts/unified_voice2.py
@@ -185,73 +185,6 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         )
 
 
-
-class GPT2AttentionWithRotaryEmbeddings(GPT2Attention):
-    def __init__(self, config, is_cross_attention=False, layer_idx=None):
-        super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
-        self.rotary_pos_emb = RotaryEmbedding(32)
-
-    def forward(
-        self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
-        if encoder_hidden_states is not None:
-            if not hasattr(self, "q_attn"):
-                raise ValueError(
-                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
-                )
-
-            query = self.q_attn(hidden_states)
-            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
-            attention_mask = encoder_attention_mask
-        else:
-            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key = torch.cat((past_key, key), dim=-2)
-            value = torch.cat((past_value, value), dim=-2)
-
-        if use_cache is True:
-            present = (key, value)
-        else:
-            present = None
-
-        # Apply rotary embeddings. This is the only difference between this implementation and the HF one.
-        rotary_pos_emb = self.rotary_pos_emb(hidden_states.shape[1], hidden_states.device)
-        l = rotary_pos_emb.shape[-1]
-        (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (query, key, value))
-        ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
-        query, key, value = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
-
-        if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
-        else:
-            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs  # a, present, (attentions)
-
-
 class ConditioningEncoder(nn.Module):
     def __init__(self,
                  spec_dim,
@@ -308,7 +241,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False, use_rotary_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -360,11 +293,6 @@ class UnifiedVoice(nn.Module):
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0
 
-        if use_rotary_embeddings:
-            # We must re-build all the attention layers as type GPT2AttentionWithRotaryEmbeddings.
-            for blk in self.gpt.h:
-                blk.attn = GPT2AttentionWithRotaryEmbeddings(self.gpt.config, layer_idx=blk.attn.layer_idx)
-
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
@@ -638,8 +566,7 @@ def register_unified_voice2(opt_net, opt):
 
 
 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4,
-                       use_rotary_embeddings=True, max_mel_tokens=-1, max_text_tokens=-1)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
     l = gpt(torch.randn(2, 3, 80, 800), torch.randint(high=256, size=(2,120)), torch.tensor([32, 120]),