Properly pass retention_mask for retnet-HF; attempt to fix recurrent forward for retnet (still doesn't work)

This commit is contained in:
mrq 2024-04-14 13:12:50 -05:00
parent 789bb5d11b
commit d69a00e389
3 changed files with 36 additions and 23 deletions

View File

@@ -936,6 +936,7 @@ class RetNetModel(RetNetPreTrainedModel):
         for idx, layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             past_key_value = (
                 past_key_values[idx] if past_key_values is not None else None
             )
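Note: this hunk sits inside RetNetModel's decoder-layer loop, where each layer pulls its own slice of the incoming cache. A minimal sketch of that per-layer cache indexing, using a hypothetical run_layers helper and a simplified layer signature (the real RetNet layer also takes retention_rel_pos, retention_mask, forward_impl, and so on):

# Sketch only: hypothetical helper illustrating the per-layer cache lookup above.
def run_layers(layers, hidden_states, past_key_values=None, use_cache=True):
    next_cache = [] if use_cache else None
    for idx, layer in enumerate(layers):
        # select this layer's slice of the previous state, if there is one
        past_key_value = past_key_values[idx] if past_key_values is not None else None
        hidden_states, layer_cache = layer(hidden_states, past_key_value=past_key_value)
        if use_cache:
            next_cache.append(layer_cache)
    # the collected caches become the past_key_values for the next call
    return hidden_states, tuple(next_cache) if use_cache else None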

View File

@@ -209,7 +209,7 @@ class AR_NAR(Base):
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
 		stopped = torch.zeros(batch_size, device=device).bool()
-		recurrent_state = {} if cfg.inference.recurrent_forward else None
+		recurrent_state = [] if cfg.inference.recurrent_forward else None
 		mirostat = [
 			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
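The change above swaps the empty recurrent state from a dict to a list: the torchscale RetNet mutates an incremental_state dict in place, while the HF-style RetNet path treats the state as a past_key_values sequence that gets replaced by out.past_key_values each step (see the class Base hunks below). The commit itself simply switches to []; the helper below is only an illustrative sketch of the two conventions side by side and does not exist in the repo:

def init_recurrent_state(arch_type: str, recurrent_forward: bool):
    # illustrative only: pick the "empty state" shape per backend
    if not recurrent_forward:
        return None
    if arch_type == "retnet":   # torchscale backend: incremental_state dict, mutated in place
        return {}
    return []                   # retnet-hf backend: past_key_values-style sequence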

View File

@@ -567,40 +567,48 @@ class Base(nn.Module):
 			padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
 			m = torch.cat([m, padding], dim=1)
-		if state is not None and self.arch_type == "retnet":
+		# for simplicity
+		mask = m.squeeze(-1).int()
+		"""
+		# Broken
+		if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
 			# prefill
 			if len(state) == 0:
 				prefill_size = x.shape[1]
 				# run the initial prompt to fill the KV cache
 				if self.arch_type == "retnet":
 					for n in range(prefill_size):
 						xi = x[:, n, :].unsqueeze(1)
 						self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
 				elif self.arch_type == "retnet-hf":
 					state = None
 					for n in range(prefill_size):
 						xi = x[:, n, :].unsqueeze(1)
 						kwargs = dict(
-							#attention_mask=m,
-							inputs_embeds=x,
-							past_key_values=state[-1],
-							use_cache=state is not None,
+							attention_mask=mask,
+							inputs_embeds=xi,
+							past_key_values=state,
+							use_cache=True,
+							forward_impl='recurrent',
 							# return_dict=True,
 						)
 						out = self.model(**kwargs)
-						state.append(out.past_key_values)
+						state = out.past_key_values
 			# grab last token(s)
 			x = x[:, -1, :].unsqueeze(1)
+		"""
 		# HF transformer derived model
-		elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
+		if self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
 			kwargs = dict(
-				#attention_mask=m,
+				attention_mask=mask,
 				inputs_embeds=x,
 				past_key_values=state,
-				use_cache=state is not None,
+				use_cache=True,
 				# return_dict=True,
 			)
 			if self.n_experts > 1 and targ_list is not None:
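The new `mask = m.squeeze(-1).int()` line converts the embedding padding mask into the (batch, seq_len) integer mask that the HF-style backends expect as attention_mask (and that retnet-hf derives its retention_mask from). A small illustration, assuming m is the usual 0/1 padding mask of shape (batch, seq_len, 1), as its later use in `x = self.classifier(x) * m` suggests:

import torch

lengths = [3, 5]
seq_len = max(lengths)
# m: 1.0 for real positions, 0.0 for padding, shaped (batch, seq_len, 1)
m = torch.stack([
    torch.cat([torch.ones(l, 1), torch.zeros(seq_len - l, 1)])
    for l in lengths
])
mask = m.squeeze(-1).int()
# mask -> tensor([[1, 1, 1, 0, 0],
#                 [1, 1, 1, 1, 1]], dtype=torch.int32), shape (2, 5)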
@@ -624,35 +632,39 @@ class Base(nn.Module):
 			x = self.sin_emb.add_pe(x)
 			# pass our inputs through the transformer
 			for block in self.blocks:
-				x = block(x, m, l)
+				x = block(x, mask, l)
 		elif self.arch_type == "retnet":
 			# pass our inputs through the RetNet
 			x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
 			if _ is not None and "l_aux" in _ and self.n_experts > 1:
 				aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
 		elif self.arch_type == "retnet-hf":
+			first = state is None or len(state) == 0
 			kwargs = dict(
-				#attention_mask=m,
-				inputs_embeds=x,
-				past_key_values=state,
-				use_cache=False, #state is not None,
-				# return_dict=True,
+				attention_mask=mask,
+				inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
+				past_key_values=None if first else state,
+				use_cache=True,
+				forward_impl='parallel' if first else 'recurrent',
+				return_dict=True,
 			)
-			t = self.model(**kwargs)
-			x = t[0]
+			out = self.model(**kwargs)
+			x = out.last_hidden_state
 			if state is not None:
-				state = t[1]
+				state = out.past_key_values
 		elif self.arch_type == "bitnet":
 			x = self.model(x)
 		# output projection layer with masking
 		x = self.classifier(x) * m
 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
 		# compute loss if the target is given
 		if targ_list is not None:
 			target_list = self._samplewise_merge_tensors(
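For reference, the retnet-hf branch above amounts to a parallel prefill on the first call followed by single-token recurrent steps that carry past_key_values forward. A sketch of that pattern as a standalone step function, reusing the same kwargs as the hunk; the function itself is illustrative and not the repo's code:

def retnet_hf_step(model, x, mask, state):
    # parallel prefill if there is no recurrent state yet, otherwise a recurrent step
    first = state is None or len(state) == 0
    out = model(
        attention_mask=mask,                                      # (batch, seq_len) int padding mask
        inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),   # full prompt once, then only the newest token
        past_key_values=None if first else state,
        use_cache=True,
        forward_impl='parallel' if first else 'recurrent',
        return_dict=True,
    )
    # out.past_key_values is the recurrent state to feed back on the next step
    return out.last_hidden_state, out.past_key_values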