Merge pull request #51 from sunyt32/retnet-official

fix chunkwise inconsistency bug
This commit is contained in:
Shuming Ma 2023-08-04 13:51:53 +08:00 committed by GitHub
commit e2db7ae123
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -147,20 +147,25 @@ class MultiScaleRetention(nn.Module):
kv_recurrent = []
cross_scale = []
kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
kv_scale = torch.ones(bsz, self.num_heads, 1, self.head_dim).to(v)
kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
# accumulate kv by loop
for i in range(num_chunks):
kv_recurrent.append(kv_state / kv_scale)
cross_scale.append(kv_scale)
kv_state = kv_state * cross_decay + kv[:, i]
kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).clamp(min=1)
kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1)
kv_recurrent = torch.stack(kv_recurrent, dim=1)
cross_scale = torch.stack(cross_scale, dim=1)
all_scale = torch.maximum(inner_scale, cross_scale)
align_inner_scale = all_scale / inner_scale
align_cross_scale = all_scale / cross_scale
cross_output = (qr * inner_decay) @ kv_recurrent
output = inner_output / cross_scale + cross_output / inner_scale
output = inner_output / align_inner_scale + cross_output / align_cross_scale
# output = inner_output / cross_scale + cross_output / inner_scale
output = output.transpose(2, 3)
return output