stable layernorm
This commit is contained in:
parent
6a2c29f596
commit
3853f37257
|
@ -211,15 +211,13 @@ class Wav2Vec2EncoderLayer(nn.Module):
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||||
attn_residual = hidden_states
|
attn_residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
hidden_states, attn_weights, _ = self.attention(
|
hidden_states, attn_weights, _ = self.attention(
|
||||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = attn_residual + hidden_states
|
hidden_states = attn_residual + hidden_states
|
||||||
|
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
|
||||||
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
@ -331,7 +329,6 @@ class Wav2Vec2Encoder(nn.Module):
|
||||||
|
|
||||||
position_embeddings = self.pos_conv_embed(hidden_states)
|
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||||
hidden_states = hidden_states + position_embeddings
|
hidden_states = hidden_states + position_embeddings
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
|
||||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||||
|
@ -349,6 +346,8 @@ class Wav2Vec2Encoder(nn.Module):
|
||||||
layer_outputs = checkpoint(layer_fn, hidden_states)
|
layer_outputs = checkpoint(layer_fn, hidden_states)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user