From d69a00e389b1a04f1989ac8e208cc7b06925b42d Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 14 Apr 2024 13:12:50 -0500
Subject: [PATCH] Properly pass retention_mask for retnet-HF, attempt to fix
 recurrent forward for retnet (doesn't work still)

---
 vall_e/ext/retnet_hf/modeling_retnet.py |  1 +
 vall_e/models/ar_nar.py                 |  2 +-
 vall_e/models/base.py                   | 56 +++++++++++++++----------
 3 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/vall_e/ext/retnet_hf/modeling_retnet.py b/vall_e/ext/retnet_hf/modeling_retnet.py
index 2a2a580..ade5bfb 100644
--- a/vall_e/ext/retnet_hf/modeling_retnet.py
+++ b/vall_e/ext/retnet_hf/modeling_retnet.py
@@ -936,6 +936,7 @@ class RetNetModel(RetNetPreTrainedModel):
         for idx, layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
+
             past_key_value = (
                 past_key_values[idx] if past_key_values is not None else None
             )
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 85dda43..0ad074e 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -209,7 +209,7 @@ class AR_NAR(Base):
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
 		stopped = torch.zeros(batch_size, device=device).bool()
 
-		recurrent_state = {} if cfg.inference.recurrent_forward else None
+		recurrent_state = [] if cfg.inference.recurrent_forward else None
 		mirostat = [
 			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index de164e0..cb514a2 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -567,40 +567,48 @@ class Base(nn.Module):
 			padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
 			m = torch.cat([m, padding], dim=1)
 
-		if state is not None and self.arch_type == "retnet":
+		# for simplicity
+		mask = m.squeeze(-1).int()
+
+		"""
+		# Broken
+		if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
 			# prefill
 			if len(state) == 0:
 				prefill_size = x.shape[1]
-
 				# run the initial prompt to fill the KV cache
 				if self.arch_type == "retnet":
 					for n in range(prefill_size):
 						xi = x[:, n, :].unsqueeze(1)
 						self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
 				elif self.arch_type == "retnet-hf":
+					state = None
 					for n in range(prefill_size):
 						xi = x[:, n, :].unsqueeze(1)
 
 						kwargs = dict(
-							#attention_mask=m,
-							inputs_embeds=x,
-							past_key_values=state[-1],
-							use_cache=state is not None,
+							attention_mask=mask,
+							inputs_embeds=xi,
+							past_key_values=state,
+							use_cache=True,
+							forward_impl='recurrent',
 							# return_dict=True,
 						)
 
 						out = self.model(**kwargs)
-						state.append(out.past_key_values)
+						state = out.past_key_values
 
 			# grab last token(s)
 			x = x[:, -1, :].unsqueeze(1)
+		"""
+
 		# HF transformer derived model
-		elif self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
+		if self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
 			kwargs = dict(
-				#attention_mask=m,
+				attention_mask=mask,
 				inputs_embeds=x,
 				past_key_values=state,
-				use_cache=state is not None,
+				use_cache=True,
 				# return_dict=True,
 			)
 			if self.n_experts > 1 and targ_list is not None:
@@ -624,35 +632,39 @@
 			x = self.sin_emb.add_pe(x)
 			# pass our inputs through the transformer
 			for block in self.blocks:
-				x = block(x, m, l)
+				x = block(x, mask, l)
 		elif self.arch_type == "retnet":
 			# pass our inputs through the RetNet
 			x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
 			if _ is not None and "l_aux" in _ and self.n_experts > 1:
 				aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
 		elif self.arch_type == "retnet-hf":
+			first = state is None or len(state) == 0
+
 			kwargs = dict(
-				#attention_mask=m,
-				inputs_embeds=x,
-				past_key_values=state,
-				use_cache=False, #state is not None,
-				# return_dict=True,
+				attention_mask=mask,
+				inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
+				past_key_values=None if first else state,
+				use_cache=True,
+				forward_impl='parallel' if first else 'recurrent',
+				return_dict=True,
 			)
 
-			t = self.model(**kwargs)
-
-			x = t[0]
-
+			out = self.model(**kwargs)
+			x = out.last_hidden_state
 			if state is not None:
-				state = t[1]
+				state = out.past_key_values
+
 		elif self.arch_type == "bitnet":
 			x = self.model(x)
+		# output projection layer with masking
+		x = self.classifier(x) * m
 
 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
-		
+
 		# compute loss if the target is given
 		if targ_list is not None:
			target_list = self._samplewise_merge_tensors(
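
Note on the retnet-hf path touched above: the forward now distinguishes the first call (no cache yet) from later calls, running the whole prompt with forward_impl='parallel' and then feeding only the newest embedding with forward_impl='recurrent' while carrying past_key_values between calls. The sketch below only illustrates that calling pattern in isolation; it is not part of the patch. The helper name decode_sketch, the bare `model` handle, the step count, and the tensor shapes are placeholders, and it assumes the retnet_hf RetNetModel accepts exactly the keyword arguments used in the hunks above (including reusing the full-length attention_mask on recurrent steps, as the patch does).

    # Hypothetical sketch of the parallel-prefill / recurrent-step pattern used by
    # the retnet-hf branch above. `model` stands in for self.model and is assumed
    # to accept the same kwargs as in the patch; shapes and step count are placeholders.
    def decode_sketch(model, inputs_embeds, mask, steps=4):
        # first pass: no cache yet, so run the whole prompt in parallel
        out = model(
            attention_mask=mask,
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=True,
            forward_impl='parallel',
            return_dict=True,
        )
        state = out.past_key_values
        hidden = out.last_hidden_state

        for _ in range(steps):
            # later passes: feed only the newest embedding and reuse the cache.
            # in the real loop (ar_nar.py) this would be the embedding of the
            # token just sampled, appended to the running sequence.
            step = inputs_embeds[:, -1, :].unsqueeze(1)
            out = model(
                attention_mask=mask,
                inputs_embeds=step,
                past_key_values=state,
                use_cache=True,
                forward_impl='recurrent',
                return_dict=True,
            )
            state = out.past_key_values
            hidden = out.last_hidden_state

        return hidden, state

As the subject line says, the recurrent forward still does not work, so treat the sketch as documentation of the intended call sequence rather than a verified decode loop.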