From e7d5ec2ad7400440f4cea0978af1a1a2060e4031 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 29 Nov 2022 00:02:26 +0100 Subject: [PATCH 1/3] remove lambda --- torchscale/component/feedforward_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchscale/component/feedforward_network.py b/torchscale/component/feedforward_network.py index 77aadb1..a423cdf 100644 --- a/torchscale/component/feedforward_network.py +++ b/torchscale/component/feedforward_network.py @@ -85,7 +85,7 @@ def get_activation_fn(activation): if activation == "relu": return F.relu elif activation == "gelu": - return lambda x: F.gelu(x.float()).type_as(x) + return F.gelu else: raise NotImplementedError @@ -121,7 +121,7 @@ class FeedForwardNetwork(nn.Module): x_shape = x.shape x = x.reshape(-1, x.size(-1)) x = self.fc1(x) - x = self.activation_fn(x) + x = self.activation_fn(x.float()).as_type(x) x = self.activation_dropout_module(x) if self.ffn_layernorm is not None: x = self.ffn_layernorm(x) From be14bc23a1e63bde6b0fd5706dea058aa16de1bd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 29 Nov 2022 00:11:02 +0100 Subject: [PATCH 2/3] typo --- torchscale/architecture/decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py index 0463800..704e438 100644 --- a/torchscale/architecture/decoder.py +++ b/torchscale/architecture/decoder.py @@ -400,7 +400,7 @@ class Decoder(nn.Module): ) x = x.transpose(0, 1) - # relative postion + # relative position self_attn_rel_pos_bias = None slen = prev_output_tokens.size(1) if self.self_attn_relative_position is not None: From c69aba2a730252a3cee1fe54b109ef32f6e959d0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 29 Nov 2022 00:11:38 +0100 Subject: [PATCH 3/3] fix call to activation_fn --- examples/fairseq/models/bert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fairseq/models/bert.py b/examples/fairseq/models/bert.py index 7bbe382..d3ffa3f 100644 --- a/examples/fairseq/models/bert.py +++ b/examples/fairseq/models/bert.py @@ -391,7 +391,7 @@ class ClassificationHead(nn.Module): x = features[:, 0, :] # take token (equiv. to [CLS]) x = self.dropout(x) x = self.dense(x) - x = self.activation_fn(x) + x = self.activation_fn(x.float()).as_type(x) x = self.dropout(x) x = self.out_proj(x) return x @@ -418,7 +418,7 @@ class LMHead(nn.Module): features = features[masked_tokens, :] x = self.dense(features) - x = self.activation_fn(x) + x = self.activation_fn(x.float()).as_type(x) x = self.layer_norm(x) # project back to size of vocabulary with bias x = F.linear(x, self.weight) + self.bias