forked from mrq/DL-Art-School

commit 3853f37257: stable layernorm
parent: 6a2c29f596
@@ -211,15 +211,13 @@ class Wav2Vec2EncoderLayer(nn.Module):
     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
         attn_residual = hidden_states
+        hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
             hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
         )
         hidden_states = self.dropout(hidden_states)
         hidden_states = attn_residual + hidden_states
-
-        hidden_states = self.layer_norm(hidden_states)
-        hidden_states = hidden_states + self.feed_forward(hidden_states)
-        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
 
         outputs = (hidden_states,)
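The hunk above switches the encoder layer from post-norm to pre-norm ("stable layer norm") ordering: layer_norm now runs before self-attention and final_layer_norm normalizes the feed-forward input, while the residual stream itself is never normalized, which is what keeps gradients well-behaved in deep stacks. Below is a minimal, runnable sketch of the resulting block; the submodule names mirror the diff, but the toy sizes, nn.MultiheadAttention, and the two-layer MLP are stand-ins assumed for illustration, not this repository's actual modules.

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    # Pre-norm residual block: normalize before each sublayer, then add the
    # un-normalized residual. Names follow the diff; sizes and the attention
    # and feed-forward implementations are illustrative assumptions.
    def __init__(self, dim=16, heads=4, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.final_layer_norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        # Attention sublayer: norm -> attend -> dropout -> residual add.
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, _ = self.attention(hidden_states, hidden_states, hidden_states)
        hidden_states = attn_residual + self.dropout(hidden_states)
        # Feed-forward sublayer, matching the new line in the diff:
        # residual + feed_forward(final_layer_norm(x)).
        return hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

x = torch.randn(2, 10, 16)      # (batch, time, dim)
print(PreNormBlock()(x).shape)  # torch.Size([2, 10, 16])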
@@ -331,7 +329,6 @@ class Wav2Vec2Encoder(nn.Module):
         position_embeddings = self.pos_conv_embed(hidden_states)
         hidden_states = hidden_states + position_embeddings
-        hidden_states = self.layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)
 
         deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
@@ -349,6 +346,8 @@ class Wav2Vec2Encoder(nn.Module):
                 layer_outputs = checkpoint(layer_fn, hidden_states)
             hidden_states = layer_outputs[0]
 
+        hidden_states = self.layer_norm(hidden_states)
+
         return hidden_states
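The last two hunks are two halves of one move: with pre-norm layers, the encoder's shared layer_norm no longer belongs right after the positional embedding, so it is removed there and applied once after the final layer, leaving the stack's output normalized exactly once on the way out. Below is a sketch of that overall shape, assuming nn.TransformerEncoderLayer with norm_first=True as a stand-in for the pre-norm blocks above and a plain Conv1d for pos_conv_embed; both are illustrative, not the repository's modules.

import torch
import torch.nn as nn

class PreNormEncoder(nn.Module):
    # Encoder skeleton matching the two hunks above: no layer_norm after the
    # positional embedding, a single layer_norm after the last layer. The
    # Conv1d and TransformerEncoderLayer here are assumed stand-ins.
    def __init__(self, dim=16, num_layers=2, dropout=0.1):
        super().__init__()
        self.pos_conv_embed = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True, norm_first=True)
            for _ in range(num_layers)
        )
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states):
        # Add convolutional positional embeddings; the early norm is gone.
        pos = self.pos_conv_embed(hidden_states.transpose(1, 2)).transpose(1, 2)
        hidden_states = self.dropout(hidden_states + pos)
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        # The norm the diff adds: applied once, after the whole stack.
        return self.layer_norm(hidden_states)

print(PreNormEncoder()(torch.randn(2, 10, 16)).shape)  # torch.Size([2, 10, 16])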
|
Loading…
Reference in New Issue
Block a user