remove lambda

This commit is contained in:
Kashif Rasul 2022-11-29 00:02:26 +01:00 committed by GitHub
parent c0ad46d7b8
commit e7d5ec2ad7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -85,7 +85,7 @@ def get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return lambda x: F.gelu(x.float()).type_as(x)
return F.gelu
else:
raise NotImplementedError
@ -121,7 +121,7 @@ class FeedForwardNetwork(nn.Module):
x_shape = x.shape
x = x.reshape(-1, x.size(-1))
x = self.fc1(x)
x = self.activation_fn(x)
x = self.activation_fn(x.float()).as_type(x)
x = self.activation_dropout_module(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)