Merge pull request #51 from sunyt32/retnet-official
fix chunkwise inconsistency bug
This commit is contained in:
commit
e2db7ae123
|
@ -147,20 +147,25 @@ class MultiScaleRetention(nn.Module):
|
|||
kv_recurrent = []
|
||||
cross_scale = []
|
||||
kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
|
||||
kv_scale = torch.ones(bsz, self.num_heads, 1, self.head_dim).to(v)
|
||||
kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
|
||||
|
||||
# accumulate kv by loop
|
||||
for i in range(num_chunks):
|
||||
kv_recurrent.append(kv_state / kv_scale)
|
||||
cross_scale.append(kv_scale)
|
||||
kv_state = kv_state * cross_decay + kv[:, i]
|
||||
kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).clamp(min=1)
|
||||
|
||||
kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1)
|
||||
|
||||
kv_recurrent = torch.stack(kv_recurrent, dim=1)
|
||||
cross_scale = torch.stack(cross_scale, dim=1)
|
||||
|
||||
all_scale = torch.maximum(inner_scale, cross_scale)
|
||||
align_inner_scale = all_scale / inner_scale
|
||||
align_cross_scale = all_scale / cross_scale
|
||||
|
||||
cross_output = (qr * inner_decay) @ kv_recurrent
|
||||
output = inner_output / cross_scale + cross_output / inner_scale
|
||||
output = inner_output / align_inner_scale + cross_output / align_cross_scale
|
||||
# output = inner_output / cross_scale + cross_output / inner_scale
|
||||
|
||||
output = output.transpose(2, 3)
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue
Block a user