forked from mrq/DL-Art-School
revert rotary embeddings work
I'm not really sure that this is going to work. I'd rather explore re-using what I've already trained.
parent 2fb9ffb0aa
commit 09ab1aa9bc
@@ -185,73 +185,6 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         )


-class GPT2AttentionWithRotaryEmbeddings(GPT2Attention):
-    def __init__(self, config, is_cross_attention=False, layer_idx=None):
-        super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
-        self.rotary_pos_emb = RotaryEmbedding(32)
-
-    def forward(
-        self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
-        if encoder_hidden_states is not None:
-            if not hasattr(self, "q_attn"):
-                raise ValueError(
-                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
-                )
-
-            query = self.q_attn(hidden_states)
-            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
-            attention_mask = encoder_attention_mask
-        else:
-            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key = torch.cat((past_key, key), dim=-2)
-            value = torch.cat((past_value, value), dim=-2)
-
-        if use_cache is True:
-            present = (key, value)
-        else:
-            present = None
-
-        # Apply rotary embeddings. This is the only difference between this implementation and the HF one.
-        rotary_pos_emb = self.rotary_pos_emb(hidden_states.shape[1], hidden_states.device)
-        l = rotary_pos_emb.shape[-1]
-        (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (query, key, value))
-        ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
-        query, key, value = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
-
-        if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
-        else:
-            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs # a, present, (attentions)
-
-
 class ConditioningEncoder(nn.Module):
     def __init__(self,
                  spec_dim,
@@ -308,7 +241,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False, use_rotary_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -360,11 +293,6 @@ class UnifiedVoice(nn.Module):
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0

-        if use_rotary_embeddings:
-            # We must re-build all the attention layers as type GPT2AttentionWithRotaryEmbeddings.
-            for blk in self.gpt.h:
-                blk.attn = GPT2AttentionWithRotaryEmbeddings(self.gpt.config, layer_idx=blk.attn.layer_idx)
-
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
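The block deleted just above swapped every stock GPT-2 attention module for the rotary variant after the transformer stack was built. A hedged sketch of that swap pattern on a plain Hugging Face GPT2Model follows; the config sizes are illustrative, GPT2Attention stands in for the rotary subclass defined in the reverted diff, and the weight-copy step is an assumption about how one might retain already-trained attention parameters (a freshly constructed module is otherwise randomly initialized).

    # Sketch only: the attention-swap pattern removed in this commit.
    from transformers import GPT2Config, GPT2Model
    from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

    config = GPT2Config(n_embd=256, n_head=4, n_layer=8)   # illustrative sizes
    gpt = GPT2Model(config)
    for blk in gpt.h:
        # In the real file the replacement would be GPT2AttentionWithRotaryEmbeddings.
        new_attn = GPT2Attention(config, layer_idx=blk.attn.layer_idx)
        # Copy the existing projection weights (c_attn / c_proj) so that
        # previously trained parameters are reused rather than discarded.
        new_attn.load_state_dict(blk.attn.state_dict(), strict=False)
        blk.attn = new_attn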
@@ -638,8 +566,7 @@ def register_unified_voice2(opt_net, opt):


 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4,
-                       use_rotary_embeddings=True, max_mel_tokens=-1, max_text_tokens=-1)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),