Properly pass retention_mask for retnet-HF, attempt to fix recurrent forward for retnet (doesn't work still)
commit d69a00e389
parent 789bb5d11b
@@ -936,6 +936,7 @@ class RetNetModel(RetNetPreTrainedModel):
         for idx, layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

             past_key_value = (
                 past_key_values[idx] if past_key_values is not None else None
             )
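For context, this hunk sits in the per-layer loop of the HF-style RetNetModel: the model-level past_key_values is a per-layer collection, and each layer is handed only its own entry (or None when no cache was provided). A minimal sketch of that pattern in isolation, assuming a toy list of layers whose call signature returns (hidden_states, updated cache entry); the real layer interface in the HF port may differ:

from typing import Optional, Tuple
import torch

def run_layers(hidden_states: torch.Tensor,
               layers: list,
               past_key_values: Optional[Tuple] = None):
    # Mirrors the loop above: each layer only ever sees its own slice of the
    # model-level cache, or None when no cache was passed in.
    next_cache = ()
    for idx, layer in enumerate(layers):
        past_key_value = (
            past_key_values[idx] if past_key_values is not None else None
        )
        # hypothetical layer signature for this sketch
        hidden_states, layer_cache = layer(hidden_states, past_key_value=past_key_value)
        next_cache += (layer_cache,)
    return hidden_states, next_cache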
@@ -209,7 +209,7 @@ class AR_NAR(Base):
         sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
         stopped = torch.zeros(batch_size, device=device).bool()

-        recurrent_state = {} if cfg.inference.recurrent_forward else None
+        recurrent_state = [] if cfg.inference.recurrent_forward else None
         mirostat = [
             {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
         ] * batch_size if sampling_mirostat_tau > 0.0 else None
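The {} to [] change matters because the forward pass in Base (hunks below) probes the state with len(state) == 0 to detect the prefill step and then overwrites it with the cache returned by the backend. A minimal sketch of how such a state is threaded through an autoregressive sampling loop; the model(...) signature and the embed helper here are hypothetical stand-ins, not this repo's actual API:

import torch

def ar_sample(model, inputs_embeds: torch.Tensor, max_steps: int, recurrent_forward: bool = True):
    # an empty list means "cache not filled yet" (prefill still pending);
    # None disables the recurrent forward entirely
    state = [] if recurrent_forward else None

    tokens = []
    for _ in range(max_steps):
        # hypothetical signature: the model consumes and returns the state
        logits, state = model(inputs_embeds=inputs_embeds, state=state)
        token = logits[:, -1, :].argmax(dim=-1)   # greedy pick at the newest position
        tokens.append(token)
        # the embedding sequence keeps growing; once a cache exists the backend
        # is expected to only process the newest position internally
        inputs_embeds = torch.cat([inputs_embeds, model.embed(token).unsqueeze(1)], dim=1)
    return torch.stack(tokens, dim=1)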
@@ -567,40 +567,48 @@ class Base(nn.Module):
             padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
             m = torch.cat([m, padding], dim=1)

-        if state is not None and self.arch_type == "retnet":
+        # for simplicity
+        mask = m.squeeze(-1).int()
+
+        """
+        # Broken
+        if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
            # prefill
            if len(state) == 0:
                prefill_size = x.shape[1]

                # run the initial prompt to fill the KV cache
                if self.arch_type == "retnet":
                    for n in range(prefill_size):
                        xi = x[:, n, :].unsqueeze(1)
                        self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
                elif self.arch_type == "retnet-hf":
+                    state = None
                    for n in range(prefill_size):
                        xi = x[:, n, :].unsqueeze(1)

                        kwargs = dict(
-                            #attention_mask=m,
-                            inputs_embeds=x,
-                            past_key_values=state[-1],
-                            use_cache=state is not None,
+                            attention_mask=mask,
+                            inputs_embeds=xi,
+                            past_key_values=state,
+                            use_cache=True,
+                            forward_impl='recurrent',
                            # return_dict=True,
                        )

                        out = self.model(**kwargs)
-                        state.append(out.past_key_values)
+                        state = out.past_key_values

            # grab last token(s)
            x = x[:, -1, :].unsqueeze(1)
+        """

         # HF transformer derived model
-        elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
+        if self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
             kwargs = dict(
-                #attention_mask=m,
+                attention_mask=mask,
                 inputs_embeds=x,
                 past_key_values=state,
-                use_cache=state is not None,
+                use_cache=True,
                 # return_dict=True,
             )
             if self.n_experts > 1 and targ_list is not None:
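The newly added mask = m.squeeze(-1).int() is what the commit title refers to: m is the per-token padding mask with a trailing singleton dimension (used later as x = self.classifier(x) * m), while the HF-style backends expect a [batch, seq_len] integer attention/retention mask. A tiny self-contained example of that conversion, with made-up shapes:

import torch

# a batch of two sequences, the second padded out to length 4 (hypothetical data)
m = torch.tensor([
    [[1.0], [1.0], [1.0], [1.0]],
    [[1.0], [1.0], [0.0], [0.0]],
])                                # [batch, seq_len, 1], the broadcastable padding mask

mask = m.squeeze(-1).int()        # [batch, seq_len] integer mask for attention_mask=
print(mask)
# tensor([[1, 1, 1, 1],
#         [1, 1, 0, 0]], dtype=torch.int32)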
@@ -624,35 +632,39 @@ class Base(nn.Module):
             x = self.sin_emb.add_pe(x)
             # pass our inputs through the transformer
             for block in self.blocks:
-                x = block(x, m, l)
+                x = block(x, mask, l)
         elif self.arch_type == "retnet":
             # pass our inputs through the RetNet
             x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
             if _ is not None and "l_aux" in _ and self.n_experts > 1:
                 aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
         elif self.arch_type == "retnet-hf":
+            first = state is None or len(state) == 0
+
             kwargs = dict(
-                #attention_mask=m,
-                inputs_embeds=x,
-                past_key_values=state,
-                use_cache=False, #state is not None,
-                # return_dict=True,
+                attention_mask=mask,
+                inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
+                past_key_values=None if first else state,
+                use_cache=True,
+                forward_impl='parallel' if first else 'recurrent',
+                return_dict=True,
             )

-            t = self.model(**kwargs)
-
-            x = t[0]
-
+            out = self.model(**kwargs)
+            x = out.last_hidden_state
             if state is not None:
-                state = t[1]
+                state = out.past_key_values

         elif self.arch_type == "bitnet":
             x = self.model(x)

         # output projection layer with masking
         x = self.classifier(x) * m

         # Remove padding
         logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]

         # compute loss if the target is given
         if targ_list is not None:
             target_list = self._samplewise_merge_tensors(
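The rewritten retnet-hf branch encodes a two-phase use of the HF RetNet port: the first call runs the whole prompt in parallel with no cache, and every later call feeds only the newest embedding through the recurrent path together with the cache returned previously. A sketch of that driving logic in isolation, assuming a model object exposing the same keyword interface as in the diff (attention_mask, inputs_embeds, past_key_values, use_cache, forward_impl, return_dict):

import torch

def retnet_hf_step(model, x: torch.Tensor, mask: torch.Tensor, state):
    # state is None or empty on the first call -> parallel prefill over the full prompt;
    # afterwards only the last embedding is fed and the recurrent path is used
    first = state is None or len(state) == 0
    out = model(
        attention_mask=mask,
        inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
        past_key_values=None if first else state,
        use_cache=True,
        forward_impl='parallel' if first else 'recurrent',
        return_dict=True,
    )
    return out.last_hidden_state, out.past_key_values

Whether len(state) is valid on the cache object the port returns depends on its cache type, and the commit message itself says the recurrent path still does not work, so this sketch should be read as the intended flow rather than a verified fix.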