From 22fe53508c18ec8a07521605f23d703ecd3c8390 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 16 Jul 2024 19:52:41 -0500
Subject: [PATCH] added experimental disjointed position IDs (because I *think*
 this might help: technically a sequence is made up of several parts, and the
 position embeddings shouldn't be unified)

---
 vall_e/config.py      |  1 +
 vall_e/models/base.py | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/vall_e/config.py b/vall_e/config.py
index e8f7560..856dd84 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -206,6 +206,7 @@ class ModelExperimentalSettings:
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
 	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
+	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 
 # I really need to clean this up
 @dataclass()
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index a784d4a..ae7c45c 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -383,6 +383,9 @@ class Base(nn.Module):
 		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
 		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
 		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
+		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
+
+		self.unified_position_ids = unified_position_ids
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -720,6 +723,7 @@ class Base(nn.Module):
 		self,
 		inputs,
 		mask = None,
+		position_ids = None,
 		state = None,
 	):
 		x = inputs
@@ -732,6 +736,7 @@ class Base(nn.Module):
 			attention_mask=m,
 			inputs_embeds=x,
 			past_key_values=state,
+			position_ids=position_ids,
 			use_cache=True,
 			# return_dict=True,
 		)
@@ -930,6 +935,31 @@ class Base(nn.Module):
 
 		return x_list
 
+	def inputs_to_position_ids(
+		self,
+		inputs: list,
+		mask: Tensor,
+	):
+		# shamelessly grabbed from modeling_llama.py
+		ids = mask.long().cumsum(-1) - 1
+		ids.masked_fill_( mask == 0, 1 )
+
+		# there's a better way
+		if not self.unified_position_ids:
+			x_list = []
+			for batch_index, batch_input in enumerate(inputs):
+				batch = torch.cat( [ torch.Tensor([*range( input.shape[0] + (0 if name == "resp" else 1) )]) for name, input in batch_input if name != "task" ] )
+
+				delta = ids[batch_index].shape[0] - batch.shape[0]
+				if delta > 0:
+					batch = torch.cat( [ batch, torch.Tensor([1] * delta) ] )
+
+				x_list.append( batch )
+
+			ids = torch.stack( x_list )
+
+		return ids.to(device=mask.device, dtype=torch.int32)
+
 	def calc_loss(
 		self,
 		inputs: list,
@@ -1097,6 +1127,7 @@ class Base(nn.Module):
 		device = x.device
 		batch_size = len(x_list)
 
+		# pure AR
 		if quant_levels is None:
 			quant_levels = [ 0 for _ in range(batch_size) ]
 
@@ -1115,11 +1146,14 @@ class Base(nn.Module):
 				padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
 				m = torch.cat([m, padding], dim=1)
 
+		# needs to be done here as we still have our raw inputs
+		position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
 		x, state, aux_loss = self._forward(
 			inputs=x,
 			mask=m,
 			state=state,
+			position_ids=position_ids,
 		)
 
 		if self.classifiers is not None:
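
Note (editor's sketch, not part of the patch): the first two lines of
inputs_to_position_ids are the llama-style trick of deriving position IDs from
the attention mask, so padded positions never advance the position counter. A
minimal standalone illustration of just that step:

    import torch

    mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # 4 real tokens, 2 padding
    ids = mask.long().cumsum(-1) - 1           # tensor([[0, 1, 2, 3, 3, 3]])
    ids.masked_fill_(mask == 0, 1)             # tensor([[0, 1, 2, 3, 1, 1]]); padding pinned to 1
    print(ids)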
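
Note (editor's sketch, not part of the patch): a toy version of the partitioned
path, using a hypothetical helper and made-up section names/lengths. Each
section's IDs restart at 0, non-"resp" sections get one extra position for
their trailing separator token, and any leftover padding is filled with 1,
mirroring the patch's inputs_to_position_ids:

    import torch

    def partitioned_position_ids(sections: list, total_length: int) -> torch.Tensor:
        # sections: (name, tensor) pairs, standing in for the patch's batch_input
        ids = torch.cat([
            torch.arange(t.shape[0] + (0 if name == "resp" else 1))
            for name, t in sections if name != "task"
        ])
        # right-pad with 1s up to the padded batch length, as the patch does
        delta = total_length - ids.shape[0]
        if delta > 0:
            ids = torch.cat([ids, torch.ones(delta, dtype=ids.dtype)])
        return ids.to(dtype=torch.int32)

    # e.g. a 3-token text section plus a 4-token resp section, padded to length 10:
    sections = [("text", torch.zeros(3)), ("resp", torch.zeros(4))]
    print(partitioned_position_ids(sections, 10))
    # partitioned: tensor([0, 1, 2, 3, 0, 1, 2, 3, 1, 1], dtype=torch.int32)
    # unified would instead be a single ramp: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]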