fix bug
parent 59fc5f7d3d
commit 05a9628309
@@ -44,14 +44,17 @@ class RetNetRelPos(nn.Module):
             mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
             mask = torch.exp(mask * self.decay[:, None, None])
             mask = torch.nan_to_num(mask)
+
+            value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
+            value_inner_decay = value_inner_decay.unsqueeze(-1)
             scale = mask.sum(dim=-1, keepdim=True).sqrt()
-            mask = mask / scale
+            inner_mask = mask / scale

             cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
-            inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
             cross_decay = cross_decay[:, None, None]
-            inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
-            retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
+            retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
         else:
             index = torch.arange(slen).to(self.decay)
             sin = torch.sin(index[:, None] * self.angle[None, :])
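The new tensors can be reproduced standalone. The sketch below is a minimal worked example, not part of the commit: the head count, chunk size, and decay schedule are illustrative assumptions. It builds inner_mask, cross_decay, query_inner_decay, and value_inner_decay exactly as in the branch above, and checks that the query/value factorization reproduces the plain decay weight across a chunk boundary.

import torch

# illustrative sizes and decay schedule; not taken from this commit
num_heads, chunk_len = 2, 4
decay = torch.log(1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float)))

block_index = torch.arange(chunk_len).to(decay)
tri = torch.tril(torch.ones(chunk_len, chunk_len))
mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~tri.bool(), float("inf"))
mask = torch.exp(mask * decay[:, None, None])   # per-head within-chunk decay weights, (num_heads, chunk_len, chunk_len)
mask = torch.nan_to_num(mask)

value_inner_decay = (mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)).unsqueeze(-1)
scale = mask.sum(dim=-1, keepdim=True).sqrt()
inner_mask = mask / scale                        # replaces the old in-place "mask = mask / scale"
cross_decay = torch.exp(decay * chunk_len)[:, None, None]
query_inner_decay = torch.exp(decay[:, None] * (block_index + 1))
query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])

print(inner_mask.shape, cross_decay.shape, query_inner_decay.shape, value_inner_decay.shape)
# torch.Size([2, 4, 4]) torch.Size([2, 1, 1]) torch.Size([2, 4, 1]) torch.Size([2, 4, 1])

# For a query at intra-chunk position i and a key at position j of the previous chunk,
# query_inner_decay[i] * value_inner_decay[j] equals the raw decay weight
# exp(decay * (chunk_len + i - j)), divided by the same per-row normalizer scale[i]
# that inner_mask applies to the within-chunk term.
i, j = 2, 1
lhs = query_inner_decay[:, i, 0] * value_inner_decay[:, j, 0]
rhs = torch.exp(decay * (chunk_len + i - j)) / scale[:, i, 0]
assert torch.allclose(lhs, rhs)

This is why the single inner_decay of the old code is split in two: the value-side factor is folded into the per-chunk kv summary, and the query-side factor is applied when reading the recurrent state, as the MultiScaleRetention hunks below show.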
@@ -346,7 +349,6 @@ class RetNetDecoder(nn.Module):
         slen = prev_output_tokens.size(1)
         # relative position
         retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
-        retention_rel_pos_no_block = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=False)
         # decoder layers
         inner_states = [x]

@@ -360,21 +362,13 @@ class RetNetDecoder(nn.Module):
             else:
                 if idx not in incremental_state:
                     incremental_state[idx] = {}

-            x_no_block, _ = layer(
-                x,
-                incremental_state[idx] if incremental_state is not None else None,
-                retention_rel_pos=retention_rel_pos_no_block,
-                chunkwise_recurrent=False,
-            )

             x, l_aux_i = layer(
                 x,
                 incremental_state[idx] if incremental_state is not None else None,
                 retention_rel_pos=retention_rel_pos,
                 chunkwise_recurrent=self.chunkwise_recurrent,
             )
-            print(x[0], x_no_block[0], (x - x_no_block).abs().max(), (x - x_no_block).abs().sum())
-            exit()
             l_aux.append(l_aux_i)
             inner_states.append(x)
@@ -116,7 +116,7 @@ class MultiScaleRetention(nn.Module):
         qr, kr, v,
         inner_mask
     ):
-        mask, cross_decay, inner_decay = inner_mask
+        mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
        bsz, tgt_len, embed_dim = v.size()
        chunk_len = mask.size(1)
        num_chunks = tgt_len // chunk_len
@@ -136,8 +136,7 @@ class MultiScaleRetention(nn.Module):
         inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim

         # reduce kv in one chunk
-        kv = kr_t @ (v * mask[:, -1, :, None])
-        kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
+        kv = kr_t @ (v * value_inner_decay)

         kv_recurrent = []
         cross_scale = []
@@ -158,7 +157,7 @@ class MultiScaleRetention(nn.Module):
         align_inner_scale = all_scale / inner_scale
         align_cross_scale = all_scale / cross_scale

-        cross_output = (qr * inner_decay) @ kv_recurrent
+        cross_output = (qr * query_inner_decay) @ kv_recurrent
         output = inner_output / align_inner_scale + cross_output / align_cross_scale
         # output = inner_output / cross_scale + cross_output / inner_scale

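Taken together, the patched lines make the chunkwise path agree with the fully parallel recurrence, which is what the removed print/exit debug block in RetNetDecoder was checking layer by layer. The sketch below replays that check at the tensor level. It is a minimal, batch-less re-derivation of chunk_recurrent_forward around the patched lines; the sizes, decay schedule, tensor layout, and the decision to drop the extra stabilization scales (inner_scale, cross_scale, all_scale) are assumptions made here for brevity, not taken from the commit.

import torch

torch.manual_seed(0)

# illustrative sizes; not taken from this commit
num_heads, key_dim, head_dim = 2, 8, 16
chunk_len, num_chunks = 4, 3
tgt_len = chunk_len * num_chunks
decay = torch.log(1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float)))

# decay tensors, as in the patched RetNetRelPos branch
block_index = torch.arange(chunk_len).to(decay)
tri = torch.tril(torch.ones(chunk_len, chunk_len))
mask = torch.nan_to_num(torch.exp(torch.masked_fill(
    block_index[:, None] - block_index[None, :], ~tri.bool(), float("inf")) * decay[:, None, None]))
value_inner_decay = (mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)).unsqueeze(-1)
scale = mask.sum(dim=-1, keepdim=True).sqrt()
inner_mask = mask / scale
cross_decay = torch.exp(decay * chunk_len)[:, None, None]
query_inner_decay = torch.exp(decay[:, None] * (block_index + 1))
query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])

# random per-head projections, shape (num_heads, tgt_len, dim)
q = torch.randn(num_heads, tgt_len, key_dim)
k = torch.randn(num_heads, tgt_len, key_dim)
v = torch.randn(num_heads, tgt_len, head_dim)

# reference: fully parallel retention with the same per-row 1/scale normalization
idx = torch.arange(tgt_len).to(decay)
full = torch.exp((idx[:, None] - idx[None, :]) * decay[:, None, None])
full = full.masked_fill(idx[:, None] < idx[None, :], 0)
full = full / scale.squeeze(-1).repeat(1, num_chunks)[..., None]
ref = ((q @ k.transpose(-1, -2)) * full) @ v

# chunkwise form, mirroring chunk_recurrent_forward with the patched lines
qc = q.view(num_heads, num_chunks, chunk_len, key_dim).transpose(0, 1)   # (num_chunks, num_heads, chunk_len, key_dim)
kc = k.view(num_heads, num_chunks, chunk_len, key_dim).transpose(0, 1)
vc = v.view(num_heads, num_chunks, chunk_len, head_dim).transpose(0, 1)
kr_t = kc.transpose(-1, -2)

inner_output = ((qc @ kr_t) * inner_mask) @ vc        # within-chunk term
kv = kr_t @ (vc * value_inner_decay)                  # per-chunk kv summary (patched line)

cross_output = torch.zeros_like(inner_output)
kv_state = torch.zeros(num_heads, key_dim, head_dim)
for c in range(num_chunks):                           # carry the summary across chunks
    cross_output[c] = (qc[c] * query_inner_decay) @ kv_state   # cross-chunk term (patched line)
    kv_state = kv_state * cross_decay + kv[c]

out_chunk = (inner_output + cross_output).transpose(0, 1).reshape(num_heads, tgt_len, head_dim)
print((out_chunk - ref).abs().max())   # ~0 up to float32 rounding: chunkwise now matches the parallel form

The stabilization scales left out here only rescale each query position by a positive factor to keep intermediates bounded, so they do not change what this comparison verifies.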