stable layernorm

James Betker 2022-05-17 16:07:03 -06:00
parent 6a2c29f596
commit 3853f37257


@@ -211,15 +211,13 @@ class Wav2Vec2EncoderLayer(nn.Module):
     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
         attn_residual = hidden_states
+        hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
             hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
         )
         hidden_states = self.dropout(hidden_states)
         hidden_states = attn_residual + hidden_states
-        hidden_states = self.layer_norm(hidden_states)
-        hidden_states = hidden_states + self.feed_forward(hidden_states)
-        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
         outputs = (hidden_states,)
@@ -331,7 +329,6 @@ class Wav2Vec2Encoder(nn.Module):
         position_embeddings = self.pos_conv_embed(hidden_states)
         hidden_states = hidden_states + position_embeddings
-        hidden_states = self.layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)
         deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
@@ -349,6 +346,8 @@ class Wav2Vec2Encoder(nn.Module):
                 layer_outputs = checkpoint(layer_fn, hidden_states)
             hidden_states = layer_outputs[0]
+        hidden_states = self.layer_norm(hidden_states)
+
         return hidden_states
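
The net effect of the commit is a switch from post-layernorm to pre-layernorm ("stable layernorm") ordering: each sub-block of Wav2Vec2EncoderLayer normalizes its input before attention or the feed-forward and adds the raw residual afterward, and Wav2Vec2Encoder applies its layer_norm once after the layer stack instead of before it. Below is a minimal, self-contained sketch of the layer ordering that results. The class name, constructor arguments, and the use of torch.nn.MultiheadAttention in place of the repo's own attention module are illustrative assumptions, not code from this repository.

    import torch
    import torch.nn as nn


    class PreLayerNormEncoderLayerSketch(nn.Module):
        """Illustrative pre-layernorm encoder layer following the ordering introduced above."""

        def __init__(self, hidden_size=768, ffn_size=3072, num_heads=12, dropout=0.1):
            super().__init__()
            # nn.MultiheadAttention stands in for the model's own attention module (assumption).
            self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
            self.dropout = nn.Dropout(dropout)
            self.feed_forward = nn.Sequential(
                nn.Linear(hidden_size, ffn_size), nn.GELU(), nn.Linear(ffn_size, hidden_size)
            )
            self.layer_norm = nn.LayerNorm(hidden_size)
            self.final_layer_norm = nn.LayerNorm(hidden_size)

        def forward(self, hidden_states):
            # Pre-LN attention block: normalize first, attend, then add the unnormalized residual.
            attn_residual = hidden_states
            hidden_states = self.layer_norm(hidden_states)
            hidden_states, _ = self.attention(hidden_states, hidden_states, hidden_states)
            hidden_states = self.dropout(hidden_states)
            hidden_states = attn_residual + hidden_states
            # Pre-LN feed-forward block: normalize only the FFN input, keep the residual raw.
            hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
            return hidden_states


    x = torch.randn(2, 50, 768)
    print(PreLayerNormEncoderLayerSketch()(x).shape)  # torch.Size([2, 50, 768])

Because each layer now leaves the residual stream unnormalized, the encoder-level part of the diff (dropping layer_norm before the layer loop and applying it once before return) is what keeps the final output normalized. This is the same ordering used by the "stable layer norm" wav2vec 2.0 variants and is generally more stable to train at depth.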