Adding sqrt in the recurrent_forward of retnet to make it consistent with parallel_forward
Adding sqrt in the recurrent_forward of retnet avoids numerical underflow, improving consistency with parallel_forward and performance. See https://github.com/microsoft/torchscale/issues/47
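For reference, the coefficient used before this commit, (1 - 1/scale), is algebraically the same as prev_scale * decay / scale, since scale = prev_scale * decay + 1; the change therefore keeps the same decayed sum but stores the running state divided by scale.sqrt() instead of scale. A minimal sketch of that identity, using hypothetical scalar values rather than the module's real tensors:

```python
# Sketch: (1 - 1/scale) == prev_scale * decay / scale, because
# scale = prev_scale * decay + 1. The commit only changes which power of
# `scale` the running state is divided by (scale -> scale.sqrt()).
import torch

prev_scale = torch.tensor(3.5)   # hypothetical values, for illustration only
decay = torch.tensor(0.9)
scale = prev_scale * decay + 1

old_coeff = 1 - 1 / scale                   # factor used before this commit
rewritten = prev_scale * decay / scale      # same quantity, written explicitly
print(torch.allclose(old_coeff, rewritten))  # True
```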
parent 7d231743f4
commit 7f0bf80a7e
@@ -105,7 +105,7 @@ class MultiScaleRetention(nn.Module):
             prev_kv = incremental_state["prev_key_value"]
             prev_scale = incremental_state["scale"]
             scale = prev_scale * decay + 1
-            kv = prev_kv * (1 - 1 / scale).view(self.num_heads, 1, 1) + kv / scale.view(self.num_heads, 1, 1)
+            kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
             # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
         else:
             scale = torch.ones_like(decay)
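To illustrate the consistency claim, here is a standalone toy check (single batch, single head, hypothetical sizes and names, not the torchscale API) comparing a parallel retention pass, assuming its decay mask is normalized by the square root of its row sums as in the retnet relative-position code, against the recurrent update introduced by this commit. The extra clamp that the real parallel_forward applies to the qk_mat row sums is omitted.

```python
# Toy consistency check, not the torchscale API: one head, no batch dim,
# hypothetical sizes. Assumes the parallel path uses a decay mask divided by
# the sqrt of its row sums; the real parallel_forward's extra clamp on the
# qk_mat row sums is omitted here.
import torch

torch.manual_seed(0)
n, d_k, d_v = 16, 8, 8
gamma = 0.9                                   # per-head decay
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)
v = torch.randn(n, d_v)

# Parallel form: D[t, i] = gamma**(t - i) for i <= t, rows divided by sqrt(row sum).
idx = torch.arange(n)
mask = torch.tril(gamma ** (idx[:, None] - idx[None, :]).float())
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
out_parallel = (q @ k.T * mask) @ v

# Recurrent form with the sqrt scaling from this commit.
out_recurrent = torch.zeros(n, d_v)
kv = torch.zeros(d_k, d_v)                    # running state, stored divided by sqrt(scale)
scale = torch.tensor(0.0)
for t in range(n):
    prev_kv, prev_scale = kv, scale
    scale = prev_scale * gamma + 1            # running sum of decay powers
    kv = prev_kv * (prev_scale.sqrt() * gamma / scale.sqrt()) \
        + torch.outer(k[t], v[t]) / scale.sqrt()
    out_recurrent[t] = q[t] @ kv

print(torch.allclose(out_parallel, out_recurrent, atol=1e-5))  # expected: True
```

Under these assumptions the two paths agree step by step, because the state kept by the recurrence is the decayed key-value sum divided by sqrt(scale), which is exactly the per-row normalization the sqrt-scaled decay mask applies in the parallel form.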