diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py
index 2dea15f..5af981e 100644
--- a/torchscale/architecture/decoder.py
+++ b/torchscale/architecture/decoder.py
@@ -31,7 +31,7 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.args = args
         self.embed_dim = args.decoder_embed_dim
-        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
+        self.dropout_module = torch.nn.Dropout(args.dropout)
 
         if args.drop_path_rate > 0:
             drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
@@ -217,7 +217,7 @@ class Decoder(nn.Module):
         super().__init__(**kwargs)
         self.args = args
 
-        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
+        self.dropout_module = torch.nn.Dropout(args.dropout)
 
         embed_dim = args.decoder_embed_dim
         self.embed_dim = embed_dim
diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index c47238b..878b69b 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -29,7 +29,7 @@ class EncoderLayer(nn.Module):
         self.embed_dim = args.encoder_embed_dim
         self.self_attn = self.build_self_attention(self.embed_dim, args)
         self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
-        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
+        self.dropout_module = torch.nn.Dropout(args.dropout)
 
         if args.drop_path_rate > 0:
             drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[
@@ -174,7 +174,7 @@ class Encoder(nn.Module):
         self.args = args
         super().__init__(**kwargs)
 
-        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
+        self.dropout_module = torch.nn.Dropout(args.dropout)
 
         embed_dim = args.encoder_embed_dim
         self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
diff --git a/torchscale/component/feedforward_network.py b/torchscale/component/feedforward_network.py
index 0c872ce..abea43b 100644
--- a/torchscale/component/feedforward_network.py
+++ b/torchscale/component/feedforward_network.py
@@ -109,10 +109,8 @@ class FeedForwardNetwork(nn.Module):
         super().__init__()
         self.embed_dim = embed_dim
         self.activation_fn = get_activation_fn(activation=str(activation_fn))
-        self.activation_dropout_module = torch.nn.Dropout(
-            activation_dropout, inplace=True
-        )
-        self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
+        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
+        self.dropout_module = torch.nn.Dropout(dropout)
         self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
         self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
         self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index d255596..392d0e9 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -48,7 +48,7 @@ class MultiheadAttention(nn.Module):
             if subln and self.self_attention
             else None
        )
-        self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
+        self.dropout_module = torch.nn.Dropout(dropout)
         self.xpos = (
             XPOS(self.head_dim, args.xpos_scale_base)
             if args.xpos_rel_pos and self.self_attention
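
Note (not part of the patch): the diff only drops inplace=True from the nn.Dropout modules and does not state a motivation. A common reason for this kind of change, offered here as an assumption, is that in-place dropout overwrites its input tensor, which breaks autograd whenever that tensor was saved for the backward pass (and can also interact badly with activation checkpointing). A minimal sketch of that failure mode, independent of torchscale, assuming a recent PyTorch:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 4, requires_grad=True)
    y = torch.sigmoid(x)                    # sigmoid saves its output y for backward

    drop = nn.Dropout(p=0.5, inplace=True)  # nn.Module defaults to training mode
    z = drop(y)                             # overwrites y in place

    try:
        z.sum().backward()                  # fails: a tensor needed for gradient
    except RuntimeError as err:             # computation was modified in place
        print("inplace dropout:", err)

    # The non-inplace form used in this patch keeps y intact and trains fine.
    y = torch.sigmoid(x)
    z = nn.Dropout(p=0.5)(y)
    z.sum().backward()
    print("non-inplace dropout: grad ok", x.grad.shape)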