diff --git a/examples/fairseq/models/bert.py b/examples/fairseq/models/bert.py index d3ffa3f..42e2687 100644 --- a/examples/fairseq/models/bert.py +++ b/examples/fairseq/models/bert.py @@ -391,7 +391,7 @@ class ClassificationHead(nn.Module): x = features[:, 0, :] # take token (equiv. to [CLS]) x = self.dropout(x) x = self.dense(x) - x = self.activation_fn(x.float()).as_type(x) + x = self.activation_fn(x.float()).type_as(x) x = self.dropout(x) x = self.out_proj(x) return x @@ -418,7 +418,7 @@ class LMHead(nn.Module): features = features[masked_tokens, :] x = self.dense(features) - x = self.activation_fn(x.float()).as_type(x) + x = self.activation_fn(x.float()).type_as(x) x = self.layer_norm(x) # project back to size of vocabulary with bias x = F.linear(x, self.weight) + self.bias diff --git a/torchscale/component/feedforward_network.py b/torchscale/component/feedforward_network.py index a423cdf..31c0651 100644 --- a/torchscale/component/feedforward_network.py +++ b/torchscale/component/feedforward_network.py @@ -121,7 +121,7 @@ class FeedForwardNetwork(nn.Module): x_shape = x.shape x = x.reshape(-1, x.size(-1)) x = self.fc1(x) - x = self.activation_fn(x.float()).as_type(x) + x = self.activation_fn(x.float()).type_as(x) x = self.activation_dropout_module(x) if self.ffn_layernorm is not None: x = self.ffn_layernorm(x)