vall-e/vall_e/models/arch/retnet.py

# https://github.com/microsoft/torchscale
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
# from retnet import RetNet
# override MultiScaleRetention's forward because training with te throws an error
from torchscale.component.multiscale_retention import MultiScaleRetention, theta_shift

def MultiScaleRetention_forward(
    self,
    x,
    rel_pos,
    chunkwise_recurrent=False,
    incremental_state=None
):
    bsz, tgt_len, _ = x.size()
    (sin, cos), inner_mask = rel_pos

    # project the input into queries, keys, values and the gate
    q = self.q_proj(x)
    k = self.k_proj(x) * self.scaling
    v = self.v_proj(x)
    g = self.g_proj(x)

    # split queries/keys into heads and apply the relative-position rotation (theta_shift)
    q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
    k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

    qr = theta_shift(q, sin, cos)
    kr = theta_shift(k, sin, cos)

    # pick the retention mode: recurrent (incremental decoding), chunkwise, or fully parallel
    if incremental_state is not None:
        output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
    elif chunkwise_recurrent:
        output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
    else:
        output = self.parallel_forward(qr, kr, v, inner_mask)

    # normalize, gate and project back out
    output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
    output = self.gate_fn(g) * output
    output = self.out_proj(output)

    return output

# monkey-patch the override onto torchscale's MultiScaleRetention
MultiScaleRetention.forward = MultiScaleRetention_forward
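
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes torchscale's RetNetConfig/RetNetDecoder keyword names shown below
# (vocab_size, decoder_embed_dim, decoder_value_embed_dim,
# decoder_retention_heads, decoder_ffn_embed_dim, decoder_layers) and the
# Decoder-style call signature (token_embeddings=..., features_only=...);
# adjust for your torchscale checkout if they differ. With the monkey-patch
# above applied, the parallel path below exercises the overridden forward.
if __name__ == "__main__":
    import torch

    config = RetNetConfig(
        vocab_size=1024,
        decoder_embed_dim=512,
        decoder_value_embed_dim=1024,
        decoder_retention_heads=8,
        decoder_ffn_embed_dim=2048,
        decoder_layers=4,
    )
    model = RetNetDecoder(config)

    # feed pre-computed embeddings directly, bypassing the token embedding table
    x = torch.randn(1, 16, config.decoder_embed_dim)
    out = model(x, token_embeddings=x, features_only=True)
    hidden = out[0] if isinstance(out, tuple) else out
    print(hidden.shape)  # expected: (1, 16, 512)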