Merge pull request #50 from wangmengzhi/main-2

Add sqrt in the recurrent_forward of RetNet to make it consistent with parallel_forward.
This commit is contained in:
Li Dong 2023-08-04 09:02:38 +08:00 committed by GitHub
commit 0faee72d6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -105,7 +105,7 @@ class MultiScaleRetention(nn.Module):
prev_kv = incremental_state["prev_key_value"]
prev_scale = incremental_state["scale"]
scale = prev_scale * decay + 1
kv = prev_kv * (1 - 1 / scale).view(self.num_heads, 1, 1) + kv / scale.view(self.num_heads, 1, 1)
kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
# kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
else:
scale = torch.ones_like(decay)
@@ -202,4 +202,4 @@ class MultiScaleRetention(nn.Module):
return output