Properly pass retention_mask for retnet-HF; attempt to fix the recurrent forward for retnet (still doesn't work)
parent 789bb5d11b
commit d69a00e389
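The first clause of the commit title, passing a retention mask to the HF RetNet port, comes down to the `mask = m.squeeze(-1).int()` line added further below: the float padding mask of shape (batch, seq_len, 1) loses its trailing feature axis and is cast to int before being handed over as `attention_mask`. A minimal sketch of that derivation; the shapes and sequence lengths here are illustrative assumptions, not values from the repo:

import torch

# Hypothetical padding mask `m` of shape (batch, seq_len, 1), as used to zero
# out padded positions in the merged embeddings; the lengths are made up.
batch, seq_len = 2, 6
lengths = torch.tensor([6, 4])

m = torch.zeros(batch, seq_len, 1)
for i, l in enumerate(lengths):
    m[i, :l, 0] = 1.0

# The diff derives the per-token mask given to the HF RetNet port by dropping
# the feature axis and casting to int: 1 = real token, 0 = padding.
mask = m.squeeze(-1).int()
print(mask)
# tensor([[1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 0, 0]], dtype=torch.int32)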
@@ -936,6 +936,7 @@ class RetNetModel(RetNetPreTrainedModel):
         for idx, layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
+
             past_key_value = (
                 past_key_values[idx] if past_key_values is not None else None
             )
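The loop shown above is the piece of the HF port that matters for the recurrent path: each decoder layer receives only its own slot of `past_key_values`, or None on the first forward when there is no cache yet. A small self-contained sketch of that per-layer cache selection, assuming an HF-style tuple with one entry per layer (the helper name is hypothetical):

from typing import Any, Optional, Tuple

# Hypothetical helper mirroring the selection in the loop above: with a cache,
# layer idx gets entry idx; with no cache (first forward), every layer gets None.
def select_layer_cache(past_key_values: Optional[Tuple[Any, ...]], idx: int) -> Optional[Any]:
    return past_key_values[idx] if past_key_values is not None else None

layers = ["layer0", "layer1", "layer2"]
cache = ("state0", "state1", "state2")

for idx, _layer in enumerate(layers):
    assert select_layer_cache(cache, idx) == cache[idx]

assert all(select_layer_cache(None, i) is None for i in range(len(layers)))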
@@ -209,7 +209,7 @@ class AR_NAR(Base):
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
 		stopped = torch.zeros(batch_size, device=device).bool()
 
-		recurrent_state = {} if cfg.inference.recurrent_forward else None
+		recurrent_state = [] if cfg.inference.recurrent_forward else None
 		mirostat = [
 			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
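The container change above (`{}` to `[]`) reflects the two state conventions in play: torchscale's RetNet mutates a caller-owned `incremental_state` dict in place, while an HF-style port returns a fresh `past_key_values` that the caller has to capture and feed back each step. A sketch of the two calling patterns with stand-in classes; neither class is the real model API:

class TorchscaleLikeRetNet:
    # Stand-in for the torchscale convention: the caller allocates a dict once
    # and the model mutates it in place on every decode step.
    def forward(self, x, incremental_state):
        incremental_state["steps"] = incremental_state.get("steps", 0) + 1
        return x

class HFLikeRetNet:
    # Stand-in for the HF convention: the model returns a new cache object that
    # the caller must hold on to and pass back on the next step.
    def forward(self, x, past_key_values=None):
        cache = list(past_key_values) if past_key_values else []
        cache.append(x)
        return x, cache

# dict state, reused as-is
state = {}
ts = TorchscaleLikeRetNet()
for token in range(3):
    ts.forward(token, incremental_state=state)
assert state["steps"] == 3

# list-like state, reassigned every step, hence `recurrent_state = []` above
hf = HFLikeRetNet()
past = None
for token in range(3):
    _, past = hf.forward(token, past_key_values=past)
assert past == [0, 1, 2]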
@@ -567,40 +567,48 @@ class Base(nn.Module):
 			padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
 			m = torch.cat([m, padding], dim=1)
 
-		if state is not None and self.arch_type == "retnet":
+		# for simplicity
+		mask = m.squeeze(-1).int()
+
+		"""
+		# Broken
+		if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
 			# prefill
 			if len(state) == 0:
 				prefill_size = x.shape[1]
 
 				# run the initial prompt to fill the KV cache
 				if self.arch_type == "retnet":
 					for n in range(prefill_size):
 						xi = x[:, n, :].unsqueeze(1)
 						self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
 				elif self.arch_type == "retnet-hf":
+					state = None
 					for n in range(prefill_size):
 						xi = x[:, n, :].unsqueeze(1)
 
 						kwargs = dict(
-							#attention_mask=m,
-							inputs_embeds=x,
-							past_key_values=state[-1],
-							use_cache=state is not None,
+							attention_mask=mask,
+							inputs_embeds=xi,
+							past_key_values=state,
+							use_cache=True,
+							forward_impl='recurrent',
 							# return_dict=True,
 						)
 
 						out = self.model(**kwargs)
-						state.append(out.past_key_values)
+						state = out.past_key_values
 
 			# grab last token(s)
 			x = x[:, -1, :].unsqueeze(1)
+		"""
 
 		# HF transformer derived model
-		elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
+		if self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
 			kwargs = dict(
-				#attention_mask=m,
+				attention_mask=mask,
 				inputs_embeds=x,
 				past_key_values=state,
-				use_cache=state is not None,
+				use_cache=True,
 				# return_dict=True,
 			)
 			if self.n_experts > 1 and targ_list is not None:
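The prefill loop that gets fenced off as `# Broken` above walks the prompt one timestep at a time, and both it and the later "grab last token(s)" line rely on the same slicing idiom: take step n of a (batch, seq_len, dim) embedding tensor while keeping a length-1 time axis. A tiny sketch of just that shape bookkeeping, with made-up dimensions:

import torch

batch, seq_len, dim = 2, 5, 4
x = torch.randn(batch, seq_len, dim)

# Feeding the prompt token by token, as the commented-out prefill loop does:
for n in range(seq_len):
    xi = x[:, n, :].unsqueeze(1)      # (batch, 1, dim), one timestep per call
    assert xi.shape == (batch, 1, dim)

# Grabbing only the newest token for a recurrent decode step uses the same idiom:
last = x[:, -1, :].unsqueeze(1)
assert last.shape == (batch, 1, dim)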
@@ -624,35 +632,39 @@ class Base(nn.Module):
 			x = self.sin_emb.add_pe(x)
 			# pass our inputs through the transformer
 			for block in self.blocks:
-				x = block(x, m, l)
+				x = block(x, mask, l)
 		elif self.arch_type == "retnet":
 			# pass our inputs through the RetNet
 			x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
 			if _ is not None and "l_aux" in _ and self.n_experts > 1:
 				aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
 		elif self.arch_type == "retnet-hf":
+			first = state is None or len(state) == 0
+
 			kwargs = dict(
-				#attention_mask=m,
-				inputs_embeds=x,
-				past_key_values=state,
-				use_cache=False, #state is not None,
-				# return_dict=True,
+				attention_mask=mask,
+				inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
+				past_key_values=None if first else state,
+				use_cache=True,
+				forward_impl='parallel' if first else 'recurrent',
+				return_dict=True,
 			)
 
-			t = self.model(**kwargs)
-			x = t[0]
+			out = self.model(**kwargs)
+			x = out.last_hidden_state
 
 			if state is not None:
-				state = t[1]
+				state = out.past_key_values
 
 		elif self.arch_type == "bitnet":
 			x = self.model(x)
 
 		# output projection layer with masking
 
 		x = self.classifier(x) * m
 
 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
 
 		# compute loss if the target is given
 		if targ_list is not None:
 			target_list = self._samplewise_merge_tensors(
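The reworked retnet-hf branch above runs one parallel pass while there is no cached state and then switches to recurrent, single-token steps that reuse the returned cache. The sketch below reproduces that control flow around a stub model; the keyword arguments (`forward_impl`, `past_key_values`, `last_hidden_state`) are assumptions about the HF RetNet port's interface, and the stub only mimics them:

import torch

class StubRetNetHF:
    """Stub standing in for the HF RetNet port; it only mimics the call shape
    the diff relies on and records which forward_impl was requested."""
    class Output:
        def __init__(self, hidden, cache):
            self.last_hidden_state = hidden
            self.past_key_values = cache

    def __call__(self, inputs_embeds, attention_mask=None, past_key_values=None,
                 use_cache=True, forward_impl="parallel", return_dict=True):
        cache = list(past_key_values) if past_key_values else []
        cache.append(forward_impl)
        return self.Output(inputs_embeds, cache)

def forward_step(model, x, mask, state):
    # Mirrors the branch in the diff: parallel prefill when no state exists yet,
    # recurrent decode of only the newest token afterwards.
    first = state is None or len(state) == 0
    out = model(
        inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
        attention_mask=mask,
        past_key_values=None if first else state,
        use_cache=True,
        forward_impl="parallel" if first else "recurrent",
        return_dict=True,
    )
    return out.last_hidden_state, out.past_key_values

model = StubRetNetHF()
x = torch.randn(2, 5, 8)
mask = torch.ones(2, 5, dtype=torch.int32)

state = None
_, state = forward_step(model, x, mask, state)   # first call: parallel prefill
_, state = forward_step(model, x, mask, state)   # later calls: recurrent steps
assert state == ["parallel", "recurrent"]

Where the real port would return per-layer recurrent states, the stub only accumulates labels; that is enough to show why the diff gates the input slice and `forward_impl` on whether a state already exists.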
|
|
Loading…
Reference in New Issue
Block a user