commit 05a9628309
parent 59fc5f7d3d
sunyt32, 2023-09-28 17:39:26 +00:00
2 changed files with 11 additions and 18 deletions


@@ -44,14 +44,17 @@ class RetNetRelPos(nn.Module):
             mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
             mask = torch.exp(mask * self.decay[:, None, None])
             mask = torch.nan_to_num(mask)
+            value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
+            value_inner_decay = value_inner_decay.unsqueeze(-1)
             scale = mask.sum(dim=-1, keepdim=True).sqrt()
-            mask = mask / scale
+            inner_mask = mask / scale
             cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
-            inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
             cross_decay = cross_decay[:, None, None]
-            inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
-            retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
+            retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
         else:
             index = torch.arange(slen).to(self.decay)
             sin = torch.sin(index[:, None] * self.angle[None, :])
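
Note: the renamed tensors appear to line up with the chunkwise-recurrent retention form of RetNet (chunk size B, per-head decay gamma); the extra divisions by scale and mask[:, -1].sum() look like numerical-stability normalizations that are re-balanced later via align_inner_scale / align_cross_scale in chunk_recurrent_forward. Roughly:

\[
D_{bj} = \gamma^{\,b-j} \;(j \le b), \qquad
\zeta_j = \gamma^{\,B-1-j}, \qquad
\xi_b = \gamma^{\,b+1}, \qquad
R_i = K_{[i]}^\top \bigl(V_{[i]} \odot \zeta\bigr) + \gamma^{B} R_{i-1}
\]
\[
\mathrm{Retention}(X_{[i]}) = \bigl(Q_{[i]} K_{[i]}^\top \odot D\bigr) V_{[i]} + \bigl(\xi \odot Q_{[i]}\bigr) R_{i-1}
\]

with inner_mask playing the role of D (divided by scale), value_inner_decay the role of zeta (normalized to sum to 1), query_inner_decay the role of xi (rescaled by the same factors), and cross_decay = gamma**B.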
@@ -346,7 +349,6 @@ class RetNetDecoder(nn.Module):
         slen = prev_output_tokens.size(1)
         # relative position
         retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
-        retention_rel_pos_no_block = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=False)
         # decoder layers
         inner_states = [x]
@@ -360,21 +362,13 @@
             else:
                 if idx not in incremental_state:
                     incremental_state[idx] = {}
-            x_no_block, _ = layer(
-                x,
-                incremental_state[idx] if incremental_state is not None else None,
-                retention_rel_pos=retention_rel_pos_no_block,
-                chunkwise_recurrent=False,
-            )
             x, l_aux_i = layer(
                 x,
                 incremental_state[idx] if incremental_state is not None else None,
                 retention_rel_pos=retention_rel_pos,
                 chunkwise_recurrent=self.chunkwise_recurrent,
             )
-            print(x[0], x_no_block[0], (x - x_no_block).abs().max(), (x - x_no_block).abs().sum())
-            exit()
             l_aux.append(l_aux_i)
             inner_states.append(x)


@@ -116,7 +116,7 @@ class MultiScaleRetention(nn.Module):
         qr, kr, v,
         inner_mask
     ):
-        mask, cross_decay, inner_decay = inner_mask
+        mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
         bsz, tgt_len, embed_dim = v.size()
         chunk_len = mask.size(1)
         num_chunks = tgt_len // chunk_len
@@ -136,8 +136,7 @@ class MultiScaleRetention(nn.Module):
         inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
         # reduce kv in one chunk
-        kv = kr_t @ (v * mask[:, -1, :, None])
-        kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
+        kv = kr_t @ (v * value_inner_decay)
         kv_recurrent = []
         cross_scale = []
@@ -158,7 +157,7 @@ class MultiScaleRetention(nn.Module):
         align_inner_scale = all_scale / inner_scale
         align_cross_scale = all_scale / cross_scale
-        cross_output = (qr * inner_decay) @ kv_recurrent
+        cross_output = (qr * query_inner_decay) @ kv_recurrent
         output = inner_output / align_inner_scale + cross_output / align_cross_scale
         # output = inner_output / cross_scale + cross_output / inner_scale
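
As a sanity check on the decomposition, here is a self-contained single-head sketch (illustrative only, not repository code: the stability scaling is dropped, and gamma, the chunk size, and the dimensions are made-up values) showing that the same inner/cross split with these decay tensors reproduces fully parallel retention:

import torch

torch.manual_seed(0)
gamma, B, num_chunks, d = 0.9, 4, 3, 8
T = B * num_chunks
q, k, v = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)

# Reference: fully parallel retention, out[i] = sum_{j<=i} gamma**(i-j) * (q_i . k_j) * v_j
idx = torch.arange(T, dtype=torch.float)
D = (gamma ** (idx[:, None] - idx[None, :])) * (idx[:, None] >= idx[None, :])
ref = (q @ k.T * D) @ v

# Chunkwise decay tensors: unnormalized counterparts of inner_mask, query_inner_decay,
# value_inner_decay, and cross_decay from the diff above
b_idx = torch.arange(B, dtype=torch.float)
inner_mask = (gamma ** (b_idx[:, None] - b_idx[None, :])) * (b_idx[:, None] >= b_idx[None, :])
query_inner_decay = gamma ** (b_idx + 1)        # xi:   gamma**(b+1) per query position
value_inner_decay = gamma ** (B - 1 - b_idx)    # zeta: gamma**(B-1-j) per value position
cross_decay = gamma ** B                        # decay applied to the carried state once per chunk

out = torch.zeros(T, d)
kv_state = torch.zeros(d, d)                    # running cross-chunk state (K^T V summary)
for c in range(num_chunks):
    qc, kc, vc = (t[c * B:(c + 1) * B] for t in (q, k, v))
    inner = (qc @ kc.T * inner_mask) @ vc                    # within-chunk retention
    cross = (qc * query_inner_decay[:, None]) @ kv_state     # contribution from earlier chunks
    out[c * B:(c + 1) * B] = inner + cross
    kv_state = cross_decay * kv_state + kc.T @ (vc * value_inner_decay[:, None])

print(torch.allclose(out, ref, atol=1e-4))      # expected: True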