Adding sqrt in the recurrent_forward of retnet to make it consistent with parallel_forward

Adding sqrt in the recurrent_forward of retnet avoids numerical underflow of the recurrent state: normalizing kv by the full cumulative scale shrinks it toward zero as scale grows, while normalizing by scale.sqrt() matches the square-root normalization parallel_forward applies to its decay mask, improving consistency and performance. https://github.com/microsoft/torchscale/issues/47
Author: wangmengzhi
Date:   2023-08-04 08:18:10 +08:00
parent 7d231743f4
commit 7f0bf80a7e

@@ -105,7 +105,7 @@ class MultiScaleRetention(nn.Module):
             prev_kv = incremental_state["prev_key_value"]
             prev_scale = incremental_state["scale"]
             scale = prev_scale * decay + 1
-            kv = prev_kv * (1 - 1 / scale).view(self.num_heads, 1, 1) + kv / scale.view(self.num_heads, 1, 1)
+            kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
             # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
         else:
             scale = torch.ones_like(decay)
@@ -202,4 +202,4 @@ class MultiScaleRetention(nn.Module):
         return output
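
Why the square root: with scale accumulating the decay sums (scale_n = decay * scale_{n-1} + 1), the old update stored kv / scale, while the decay mask in parallel_forward is normalized by the square root of its cumulative decay sums, so the two paths disagreed and the extra division pushed the state toward zero in low precision. The following minimal sketch (not the torchscale source; a single scalar decay stands in for the per-head decay vector, and random matrices kv_n stand in for each step's key-value outer product) checks that the patched update keeps kv equal to the unnormalized state divided by scale.sqrt():

import torch

torch.manual_seed(0)
decay = torch.tensor(0.9)
steps = [torch.randn(4, 4) for _ in range(50)]   # stand-ins for k^T v per step

S = torch.zeros(4, 4)       # unnormalized state: sum_i decay**(n-i) * kv_i
scale = torch.tensor(0.0)   # cumulative decay sum: sum_i decay**(n-i)
kv = torch.zeros(4, 4)      # sqrt-normalized state carried across steps

for kv_n in steps:
    prev_scale = scale
    scale = prev_scale * decay + 1
    S = S * decay + kv_n
    # the patched update: renormalize from prev_scale.sqrt() to scale.sqrt()
    kv = kv * (prev_scale.sqrt() * decay / scale.sqrt()) + kv_n / scale.sqrt()
    assert torch.allclose(kv, S / scale.sqrt(), atol=1e-5)

The old coefficients (1 - 1 / scale) and 1 / scale satisfy the same recurrence for S / scale instead; since scale grows toward 1 / (1 - decay), that variant keeps the stored state smaller by an extra factor of scale.sqrt(), which is where the underflow came from.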