Adding sqrt in the recurrent_forward of RetNet to make it consistent with parallel_forward
Adding sqrt in the recurrent_forward of RetNet avoids numerical underflow, improving consistency and performance. https://github.com/microsoft/torchscale/issues/47
This commit is contained in:
parent 7d231743f4
commit 7f0bf80a7e
@@ -105,7 +105,7 @@ class MultiScaleRetention(nn.Module):
             prev_kv = incremental_state["prev_key_value"]
             prev_scale = incremental_state["scale"]
             scale = prev_scale * decay + 1
-            kv = prev_kv * (1 - 1 / scale).view(self.num_heads, 1, 1) + kv / scale.view(self.num_heads, 1, 1)
+            kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
             # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
         else:
             scale = torch.ones_like(decay)
@@ -202,4 +202,4 @@ class MultiScaleRetention(nn.Module):
         return output
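For intuition, here is a minimal, self-contained sketch of the sqrt-normalized recurrent state update from the hunk above. The function name recurrent_step, the tensor shapes, the decay values, and the zero-initialized first step are illustrative assumptions rather than the torchscale API; only the kv and scale update mirrors the changed line.

import torch

# Illustrative sizes; not taken from torchscale.
num_heads, key_dim, value_dim = 4, 16, 16

def recurrent_step(prev_kv, prev_scale, k, v, decay):
    # prev_kv:    [num_heads, key_dim, value_dim]  running key-value state
    # prev_scale: [num_heads]                      running normalizer
    # k:          [num_heads, 1, key_dim], v: [num_heads, 1, value_dim]
    # decay:      [num_heads]                      per-head decay in (0, 1)
    kv = k.transpose(-1, -2) @ v       # outer product k^T v for this step
    scale = prev_scale * decay + 1     # normalizer, converges toward 1 / (1 - decay)
    # sqrt-normalized update (the new behaviour): the carried state is rescaled by
    # sqrt(prev_scale) * decay / sqrt(scale), so it is only divided by sqrt(scale)
    # rather than by the full scale, keeping its magnitude from shrinking toward
    # zero (and potentially underflowing in low-precision dtypes) on long sequences.
    kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(num_heads, 1, 1) \
        + kv / scale.sqrt().view(num_heads, 1, 1)
    return kv, scale

# First step: zero state and zero scale give scale = 1, so kv passes through unchanged.
kv = torch.zeros(num_heads, key_dim, value_dim)
scale = torch.zeros(num_heads)
decay = torch.rand(num_heads) * 0.1 + 0.9   # illustrative decays close to 1

for _ in range(1000):
    k = torch.randn(num_heads, 1, key_dim)
    v = torch.randn(num_heads, 1, value_dim)
    kv, scale = recurrent_step(kv, scale, k, v, decay)

print(kv.abs().mean())  # the state keeps a usable magnitude instead of being divided by the full normalizer

Relative to the old line (note that 1 - 1/scale equals prev_scale * decay / scale, since scale = prev_scale * decay + 1), the state is now normalized by sqrt(scale) rather than by scale, so it shrinks far less as the running normalizer grows, which is what avoids the underflow described in the linked issue.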