diff --git a/torchscale/component/feedforward_network.py b/torchscale/component/feedforward_network.py
index 77aadb1..a423cdf 100644
--- a/torchscale/component/feedforward_network.py
+++ b/torchscale/component/feedforward_network.py
@@ -85,7 +85,7 @@ def get_activation_fn(activation):
     if activation == "relu":
         return F.relu
     elif activation == "gelu":
-        return lambda x: F.gelu(x.float()).type_as(x)
+        return F.gelu
     else:
         raise NotImplementedError
@@ -121,7 +121,7 @@ class FeedForwardNetwork(nn.Module):
         x_shape = x.shape
         x = x.reshape(-1, x.size(-1))
         x = self.fc1(x)
-        x = self.activation_fn(x)
+        x = self.activation_fn(x.float()).type_as(x)
         x = self.activation_dropout_module(x)
         if self.ffn_layernorm is not None:
             x = self.ffn_layernorm(x)
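
For reviewers, a minimal sketch (not part of the patch) of the behavior this change preserves: GELU is still evaluated in float32 and cast back to the input dtype; only the upcast moves from the lambda returned by `get_activation_fn` into `FeedForwardNetwork.forward`. The tensor `x` below is a hypothetical stand-in for the output of `self.fc1`.

```python
import torch
import torch.nn.functional as F

# Hypothetical half-precision activations, standing in for fc1's output.
x = torch.randn(4, 8, dtype=torch.float16)

# After the patch, get_activation_fn("gelu") returns F.gelu directly ...
activation_fn = F.gelu

# ... and the forward pass performs the fp32 upcast itself,
# then casts the result back to the input dtype.
y = activation_fn(x.float()).type_as(x)

assert y.dtype == torch.float16
```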