diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py index b4a313f..37a7661 100644 --- a/torchscale/architecture/decoder.py +++ b/torchscale/architecture/decoder.py @@ -462,7 +462,7 @@ class Decoder(nn.Module): return x, { "inner_states": inner_states, "l_aux": l_aux, - "attn": [layer_attn.mean(dim=0)], + "attn": None, } def output_layer(self, features):