# https://github.com/microsoft/torchscale

from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
# from retnet import RetNet

# Override MultiScaleRetention's forward because training with te (presumably
# NVIDIA's Transformer Engine) throws an error.
from torchscale.component.multiscale_retention import MultiScaleRetention, theta_shift


def MultiScaleRetention_forward(
    self,
    x,
    rel_pos,
    chunkwise_recurrent=False,
    incremental_state=None,
):
    bsz, tgt_len, _ = x.size()
    # rel_pos packs the rotary-style angles (sin, cos) and the per-head decay mask.
    (sin, cos), inner_mask = rel_pos

    q = self.q_proj(x)
    # Scale k out of place; the in-place `k *= self.scaling` used upstream
    # appears to be what breaks when training with te.
    k = self.k_proj(x) * self.scaling
    v = self.v_proj(x)
    g = self.g_proj(x)

    # Split heads: (bsz, num_heads, tgt_len, key_dim).
    q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
    k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

    # Rotate queries and keys by the relative-position angles.
    qr = theta_shift(q, sin, cos)
    kr = theta_shift(k, sin, cos)

    # Three mathematically equivalent retention paths: recurrent for
    # incremental decoding, chunkwise recurrent for long sequences, and
    # fully parallel otherwise.
    if incremental_state is not None:
        output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
    elif chunkwise_recurrent:
        output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
    else:
        output = self.parallel_forward(qr, kr, v, inner_mask)

    # Normalize each head, then flatten heads back to (bsz, tgt_len, num_heads * head_dim).
    output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)

    # Gate the retention output with the gating projection g.
    output = self.gate_fn(g) * output

    output = self.out_proj(output)

    return output


# Monkey-patch the upstream class so every retention layer built from
# torchscale picks up the fixed forward above.
MultiScaleRetention.forward = MultiScaleRetention_forward
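
# Minimal usage sketch (an assumption, not part of the original file): the
# RetNetConfig fields and the RetNetDecoder call below follow torchscale's
# RetNet API as I understand it and may differ across versions.
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    config = RetNetConfig(
        vocab_size=1000,
        decoder_embed_dim=256,
        decoder_value_embed_dim=512,
        decoder_retention_heads=4,
        decoder_ffn_embed_dim=512,
        decoder_layers=2,
    )
    embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim)
    model = RetNetDecoder(config, embed_tokens=embed_tokens)

    tokens = torch.randint(0, config.vocab_size, (2, 16))
    # Each layer's MultiScaleRetention now runs the patched forward.
    logits, _ = model(tokens)
    print(logits.shape)  # expected: (2, 16, 1000)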