added experimental disjointed position IDs (I *think* this might help: a sequence is technically composed of several distinct parts, so the position embeddings arguably shouldn't be unified across them)

mrq 2024-07-16 19:52:41 -05:00
parent fe0f235335
commit 22fe53508c
2 changed files with 35 additions and 0 deletions
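
For intuition, a small illustration of the difference (the segment lengths here are made up, and separator tokens are ignored): a sequence concatenated from three parts of lengths 3, 2, and 3 can be numbered one of two ways.

# unified: one running count across the whole concatenated sequence
unified  = [0, 1, 2, 3, 4, 5, 6, 7]
# disjoint: the count restarts at 0 for each part
disjoint = [0, 1, 2, 0, 1, 0, 1, 2]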


@@ -206,6 +206,7 @@ class ModelExperimentalSettings:
	kv_heads: int = 0 # MHA or GQA (for supported backends)
	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
	unified_position_ids: bool = True # False will generate position IDs partitioned for each section

# I really need to clean this up
@dataclass()
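
A minimal sketch of opting into the new behavior, assuming the settings dataclass is constructed directly (in practice it is populated from the model config, and the import path below is an assumption):

# assumed import path; adjust to wherever ModelExperimentalSettings actually lives
from vall_e.config import ModelExperimentalSettings

experimental = ModelExperimentalSettings(
	unified_position_ids=False, # generate position IDs partitioned per section
)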


@@ -383,6 +383,9 @@ class Base(nn.Module):
		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True

		self.unified_position_ids = unified_position_ids

		self.text_emb = Embedding(n_text_tokens, d_model)
		self.langs_emb = None
@@ -720,6 +723,7 @@ class Base(nn.Module):
		self,
		inputs,
		mask = None,
		position_ids = None,
		state = None,
	):
		x = inputs
@@ -732,6 +736,7 @@ class Base(nn.Module):
			attention_mask=m,
			inputs_embeds=x,
			past_key_values=state,
			position_ids=position_ids,
			use_cache=True,
			# return_dict=True,
		)
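
Since the kwarg is forwarded verbatim, any Hugging Face backend that accepts position_ids will apply them in its rotary embedding. A self-contained sketch against transformers' LlamaModel (the tiny config and the disjoint IDs below are illustrative, not the repo's actual settings):

import torch
from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(
	vocab_size=32, hidden_size=64, intermediate_size=128,
	num_hidden_layers=2, num_attention_heads=4,
))
x = torch.randn(1, 8, 64) # inputs_embeds, analogous to x above
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1]]) # disjoint per-section IDs
out = model(inputs_embeds=x, position_ids=position_ids, use_cache=True)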
@@ -930,6 +935,31 @@ class Base(nn.Module):
		return x_list

	def inputs_to_position_ids(
		self,
		inputs: list,
		mask: Tensor,
	):
		# shamelessly grabbed from modeling_llama.py
		ids = mask.long().cumsum(-1) - 1
		ids.masked_fill_( mask == 0, 1 )

		# there's probably a better way to do this
		if not self.unified_position_ids:
			x_list = []
			for batch_index, batch_input in enumerate(inputs):
				# each segment restarts its position IDs at 0; non-resp segments get
				# one extra position to cover the separator appended after them
				batch = torch.cat( [
					torch.arange( seq.shape[0] + (0 if name == "resp" else 1) )
					for name, seq in batch_input if name != "task"
				] )
				# pad out to the mask's sequence length, reusing the dummy position 1
				delta = ids[batch_index].shape[0] - batch.shape[0]
				if delta > 0:
					batch = torch.cat( [ batch, torch.full( (delta,), 1, dtype=batch.dtype ) ] )
				x_list.append( batch )
			ids = torch.stack( x_list )

		return ids.to(device=mask.device, dtype=torch.int32)
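
A worked example of the non-unified path (segment names follow the repo's conventions; the concrete lengths are invented): for a batch entry made of a 3-token text segment, a 2-token prom segment, and a 4-token resp, each non-resp segment gets one extra position for its trailing separator, so the IDs come out as:

# text: 0..3 (3 tokens + sep) | prom: 0..2 (2 tokens + sep) | resp: 0..3 (no sep)
# [0, 1, 2, 3,  0, 1, 2,  0, 1, 2, 3]

Anything past the real sequence is filled with the dummy position 1, matching the masked_fill_ above.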
def calc_loss(
self,
inputs: list,
@@ -1097,6 +1127,7 @@ class Base(nn.Module):
		device = x.device
		batch_size = len(x_list)

		# pure AR
		if quant_levels is None:
			quant_levels = [ 0 for _ in range(batch_size) ]
@@ -1115,11 +1146,14 @@ class Base(nn.Module):
			padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
			m = torch.cat([m, padding], dim=1)

		# needs to be done here as we still have our raw inputs
		position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None

		x, state, aux_loss = self._forward(
			inputs=x,
			mask=m,
			state=state,
			position_ids=position_ids,
		)

		if self.classifiers is not None:
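
Note the guard: with unified_position_ids left at its default of True, position_ids stays None and the backend derives its usual sequential IDs, so existing behavior is unchanged. For reference, a sketch of the fallback that modeling_llama.py uses when no position_ids are passed (the variable names here are assumed and version-dependent):

# assumed names; mirrors the stock fallback when position_ids is None
position_ids = torch.arange(
	past_length, past_length + seq_length, dtype=torch.long, device=device
).unsqueeze(0)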