Merge pull request #50 from wangmengzhi/main-2
Adding sqrt in the recurrent_forward of retnet to make it consistent with parallel_forward
commit 0faee72d6f
@@ -105,7 +105,7 @@ class MultiScaleRetention(nn.Module):
             prev_kv = incremental_state["prev_key_value"]
             prev_scale = incremental_state["scale"]
             scale = prev_scale * decay + 1
-            kv = prev_kv * (1 - 1 / scale).view(self.num_heads, 1, 1) + kv / scale.view(self.num_heads, 1, 1)
+            kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
             # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
         else:
             scale = torch.ones_like(decay)
@@ -202,4 +202,4 @@ class MultiScaleRetention(nn.Module):
         return output
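
Why the sqrt: with scale = prev_scale * decay + 1 accumulating the decay mass, the patched recurrence keeps the state equal to the decay-weighted sum of k_i^T v_i divided by the square root of that accumulated mass, rather than by the mass itself. Per the PR title, this is what lines the recurrent path up with parallel_forward. Below is a minimal single-head numerical sketch, not the torchscale implementation: the names k, v, decay, kv_state and the helper parallel_state are illustrative only. It checks that the patched recurrent update reproduces a parallel-style decay-weighted sum normalized by the square root of the accumulated decay.

# Illustrative check only; simplified to one head and no .view(self.num_heads, 1, 1).
import torch

torch.manual_seed(0)
T, d_k, d_v = 5, 4, 4
decay = torch.tensor(0.9)
k = torch.randn(T, d_k)
v = torch.randn(T, d_v)

# Parallel-style reference state after n steps: sum_i decay^(n-i) * k_i^T v_i,
# divided by sqrt(sum_i decay^(n-i)).
def parallel_state(n):
    weights = decay ** torch.arange(n - 1, -1, -1, dtype=torch.float)
    kv = sum(w * torch.outer(k[i], v[i]) for i, w in enumerate(weights))
    return kv / weights.sum().sqrt()

# Recurrent update following the patched recurrent_forward logic.
kv_state = torch.zeros(d_k, d_v)
scale = torch.tensor(1.0)  # first step: scale = torch.ones_like(decay)
for n in range(T):
    cur_kv = torch.outer(k[n], v[n])
    if n == 0:
        kv_state = cur_kv / scale.sqrt()
    else:
        prev_scale = scale
        scale = prev_scale * decay + 1
        kv_state = (kv_state * (prev_scale.sqrt() * decay / scale.sqrt())
                    + cur_kv / scale.sqrt())
    assert torch.allclose(kv_state, parallel_state(n + 1), atol=1e-6)

print("sqrt-scaled recurrent state matches the parallel-style reference")

With the old 1/scale normalization the same check fails, since the recurrent state would be divided by the full decay mass while the sqrt-normalized reference is not.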