fix bug
commit 05a9628309
parent 59fc5f7d3d
@@ -44,14 +44,17 @@ class RetNetRelPos(nn.Module):
             mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
             mask = torch.exp(mask * self.decay[:, None, None])
             mask = torch.nan_to_num(mask)
 
+            value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
+            value_inner_decay = value_inner_decay.unsqueeze(-1)
             scale = mask.sum(dim=-1, keepdim=True).sqrt()
-            mask = mask / scale
+            inner_mask = mask / scale
 
             cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
-            inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
+            query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
             cross_decay = cross_decay[:, None, None]
-            inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
-            retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
+            retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
         else:
             index = torch.arange(slen).to(self.decay)
             sin = torch.sin(index[:, None] * self.angle[None, :])
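For reference, the chunkwise branch above can be exercised standalone. The following is a minimal sketch, not the committed code: the head count, chunk size, and the log-decay parameterization are illustrative assumptions (chosen to match the usual RetNet setup), and only the construction and shapes of the four returned tensors are meant to be informative.

import torch

# Toy configuration (illustrative only; not taken from the commit).
num_heads, chunk_len = 2, 4
# Per-head log-decay; 1 - 2**(-5 - h) is the usual RetNet choice, assumed here.
decay = torch.log(1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float)))

block_index = torch.arange(chunk_len).float()
tril = torch.tril(torch.ones(chunk_len, chunk_len))
# Intra-chunk decay mask: exp(decay * (i - j)) for j <= i, zero above the diagonal.
mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~tril.bool(), float("inf"))
mask = torch.exp(mask * decay[:, None, None])
mask = torch.nan_to_num(mask)

# Weight of each position's value in the chunk-level KV summary.
value_inner_decay = (mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)).unsqueeze(-1)
scale = mask.sum(dim=-1, keepdim=True).sqrt()
inner_mask = mask / scale

# Decay applied to the carried-over state across one whole chunk,
# and the per-position decay applied on the query side within a chunk.
cross_decay = torch.exp(decay * chunk_len)[:, None, None]
query_inner_decay = torch.exp(decay[:, None] * (block_index + 1))
query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])

print(inner_mask.shape, cross_decay.shape, query_inner_decay.shape, value_inner_decay.shape)
# torch.Size([2, 4, 4]) torch.Size([2, 1, 1]) torch.Size([2, 4, 1]) torch.Size([2, 4, 1])

These four tensors are what the new retention_rel_pos tuple carries into the chunkwise forward of MultiScaleRetention below.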
@@ -346,7 +349,6 @@ class RetNetDecoder(nn.Module):
         slen = prev_output_tokens.size(1)
         # relative position
         retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
-        retention_rel_pos_no_block = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=False)
         # decoder layers
         inner_states = [x]
 
@@ -361,20 +363,12 @@ class RetNetDecoder(nn.Module):
                 if idx not in incremental_state:
                     incremental_state[idx] = {}
 
-            x_no_block, _ = layer(
-                x,
-                incremental_state[idx] if incremental_state is not None else None,
-                retention_rel_pos=retention_rel_pos_no_block,
-                chunkwise_recurrent=False,
-            )
             x, l_aux_i = layer(
                 x,
                 incremental_state[idx] if incremental_state is not None else None,
                 retention_rel_pos=retention_rel_pos,
                 chunkwise_recurrent=self.chunkwise_recurrent,
             )
-            print(x[0], x_no_block[0], (x - x_no_block).abs().max(), (x - x_no_block).abs().sum())
-            exit()
             l_aux.append(l_aux_i)
             inner_states.append(x)
 
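The lines removed here were a temporary in-forward check: each layer was run a second time through the non-chunkwise path and the difference was printed before exiting. A self-contained sketch of the same parallel-vs-chunkwise equivalence, which could live in a test instead, is shown below. It uses a simplified single-head retention without the normalization/scale terms of the real implementation; gamma, the sizes, and the variable names are illustrative only.

import torch

torch.manual_seed(0)

# Toy single-head retention, no normalization: o_i = sum_{j<=i} gamma**(i-j) * (q_i . k_j) * v_j
T, D, gamma, chunk = 8, 4, 0.9, 4
q, k, v = torch.randn(T, D), torch.randn(T, D), torch.randn(T, D)

# Parallel form.
idx = torch.arange(T).float()
decay_mask = torch.tril(gamma ** (idx[:, None] - idx[None, :]))
parallel_out = (q @ k.t() * decay_mask) @ v

# Chunkwise-recurrent form: intra-chunk part plus a decayed cross-chunk KV state.
cidx = torch.arange(chunk).float()
inner_mask = torch.tril(gamma ** (cidx[:, None] - cidx[None, :]))
query_inner_decay = gamma ** (cidx + 1)          # decay from the carried state to position t
value_inner_decay = gamma ** (chunk - 1 - cidx)  # decay from position s to the end of its chunk
cross_decay = gamma ** chunk                     # whole-chunk decay of the carried state

outs, kv_state = [], torch.zeros(D, D)
for qc, kc, vc in zip(q.split(chunk), k.split(chunk), v.split(chunk)):
    inner = (qc @ kc.t() * inner_mask) @ vc
    cross = (qc * query_inner_decay[:, None]) @ kv_state
    outs.append(inner + cross)
    kv_state = kv_state * cross_decay + kc.t() @ (vc * value_inner_decay[:, None])
chunk_out = torch.cat(outs)

print(torch.allclose(parallel_out, chunk_out, atol=1e-5))  # True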
@@ -116,7 +116,7 @@ class MultiScaleRetention(nn.Module):
         qr, kr, v,
         inner_mask
     ):
-        mask, cross_decay, inner_decay = inner_mask
+        mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
         bsz, tgt_len, embed_dim = v.size()
         chunk_len = mask.size(1)
         num_chunks = tgt_len // chunk_len
@@ -136,8 +136,7 @@ class MultiScaleRetention(nn.Module):
         inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
 
         # reduce kv in one chunk
-        kv = kr_t @ (v * mask[:, -1, :, None])
-        kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
+        kv = kr_t @ (v * value_inner_decay)
 
         kv_recurrent = []
         cross_scale = []
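As a shape sketch of the changed reduction: each chunk is collapsed into a single key_dim x head_dim state, with values pre-weighted by value_inner_decay, i.e. by how much they will have decayed by the end of their chunk. The per-chunk layout below (batch, chunks, heads, chunk length, feature dim) is assumed from the surrounding reshapes in chunk_recurrent_forward, and the sizes are arbitrary.

import torch

# Arbitrary toy sizes; layout assumed, not taken verbatim from the commit.
bsz, num_chunks, num_heads, chunk_len, key_dim, head_dim = 2, 3, 2, 4, 8, 8
kr_t = torch.randn(bsz, num_chunks, num_heads, key_dim, chunk_len)
v = torch.randn(bsz, num_chunks, num_heads, chunk_len, head_dim)
value_inner_decay = torch.rand(num_heads, chunk_len, 1)  # broadcasts over batch and chunks

# One key_dim x head_dim summary per (batch, chunk, head).
kv = kr_t @ (v * value_inner_decay)
print(kv.shape)  # torch.Size([2, 3, 2, 8, 8])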
@@ -158,7 +157,7 @@ class MultiScaleRetention(nn.Module):
         align_inner_scale = all_scale / inner_scale
         align_cross_scale = all_scale / cross_scale
 
-        cross_output = (qr * inner_decay) @ kv_recurrent
+        cross_output = (qr * query_inner_decay) @ kv_recurrent
         output = inner_output / align_inner_scale + cross_output / align_cross_scale
         # output = inner_output / cross_scale + cross_output / inner_scale
 
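The alignment in the unchanged lines rescales the two contributions onto a shared normalizer. Assuming (as the names suggest) that inner_output and cross_output were each already divided by their own scale upstream, dividing them by all_scale / inner_scale and all_scale / cross_scale gives (raw_inner + raw_cross) / all_scale. A scalar sanity check; the numbers and the choice of all_scale as the larger of the two scales are purely illustrative.

import torch

# Illustrative scalars only.
raw_inner, raw_cross = torch.tensor(8.0), torch.tensor(8.0)
inner_scale, cross_scale = torch.tensor(4.0), torch.tensor(8.0)
inner_output, cross_output = raw_inner / inner_scale, raw_cross / cross_scale

all_scale = torch.maximum(inner_scale, cross_scale)  # assumed shared normalizer
align_inner_scale = all_scale / inner_scale
align_cross_scale = all_scale / cross_scale

output = inner_output / align_inner_scale + cross_output / align_cross_scale
print(output, (raw_inner + raw_cross) / all_scale)  # both tensor(2.)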