From 3853f372573b3520d01708a3587e9388da711acd Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 17 May 2022 16:07:03 -0600 Subject: [PATCH] stable layernorm --- codes/models/audio/mel2vec.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index cd984199..3f4637a5 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -211,15 +211,13 @@ class Wav2Vec2EncoderLayer(nn.Module): def forward(self, hidden_states, attention_mask=None, output_attentions=False): attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) hidden_states, attn_weights, _ = self.attention( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions ) hidden_states = self.dropout(hidden_states) hidden_states = attn_residual + hidden_states - - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states + self.feed_forward(hidden_states) - hidden_states = self.final_layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) outputs = (hidden_states,) @@ -331,7 +329,6 @@ class Wav2Vec2Encoder(nn.Module): position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings - hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() @@ -349,6 +346,8 @@ class Wav2Vec2Encoder(nn.Module): layer_outputs = checkpoint(layer_fn, hidden_states) hidden_states = layer_outputs[0] + hidden_states = self.layer_norm(hidden_states) + return hidden_states