added experimental disjointed position IDs (I *think* this might help: a sequence is technically made up of several parts, so the position embeddings shouldn't be unified across them)
parent fe0f235335
commit 22fe53508c
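For context on what "disjointed" means here: with unified position IDs, the whole concatenated sequence (text / prompt / response sections) shares one monotonically increasing index, while with the new flag off each section restarts its count. A toy sketch of the intended difference, with made-up section lengths (not the repo's code):

# toy illustration only: three sections (e.g. text / prompt / response) of lengths 4, 6, 3
lengths = [4, 6, 3]

unified = list(range(sum(lengths)))                   # one shared index
partitioned = [i for n in lengths for i in range(n)]  # index restarts per section

print(unified)      # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
print(partitioned)  # [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2]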
@@ -206,6 +206,7 @@ class ModelExperimentalSettings:
	kv_heads: int = 0 # MHA or GQA (for supported backends)
	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
	unified_position_ids: bool = True # False will generate position IDs partitioned for each section

# I really need to clean this up
@dataclass()
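If the surrounding config plumbing works the way this dataclass suggests, toggling the experiment would look roughly like the snippet below; only the field name and its default come from this commit, the instantiation itself (and other fields keeping their defaults) is an assumption:

# hypothetical usage; only `unified_position_ids` and its default are from this commit
experimental = ModelExperimentalSettings(
	unified_position_ids=False,  # False => position IDs partitioned per input section
)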
@@ -383,6 +383,9 @@ class Base(nn.Module):
		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True

		self.unified_position_ids = unified_position_ids

		self.text_emb = Embedding(n_text_tokens, d_model)
		self.langs_emb = None
@@ -720,6 +723,7 @@ class Base(nn.Module):
		self,
		inputs,
		mask = None,
		position_ids = None,
		state = None,
	):
		x = inputs
@@ -732,6 +736,7 @@ class Base(nn.Module):
			attention_mask=m,
			inputs_embeds=x,
			past_key_values=state,
			position_ids=position_ids,
			use_cache=True,
			# return_dict=True,
		)
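The `position_ids` kwarg is simply threaded through to the backend's forward call. Assuming the backend is an HF-style decoder (e.g. LLaMA, as the kwargs above suggest), explicit position IDs are accepted as-is alongside `inputs_embeds`; the tiny config below is made up purely to keep the sketch runnable:

# minimal sketch: HF-style decoders accept explicit position_ids with inputs_embeds
import torch
from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(
	vocab_size=32, hidden_size=64, intermediate_size=128,
	num_hidden_layers=1, num_attention_heads=4,
))
embeds = torch.randn(1, 8, 64)                           # stands in for x / inputs_embeds
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]])  # e.g. partitioned IDs
out = model(inputs_embeds=embeds, position_ids=position_ids, use_cache=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 64])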
@@ -930,6 +935,31 @@ class Base(nn.Module):

		return x_list

	def inputs_to_position_ids(
		self,
		inputs: list,
		mask: Tensor,
	):
		# shamelessly grabbed from modeling_llama.py
		ids = mask.long().cumsum(-1) - 1
		ids.masked_fill_( mask == 0, 1 )

		# there's a better way
		if not self.unified_position_ids:
			x_list = []
			for batch_index, batch_input in enumerate(inputs):
				batch = torch.cat( [ torch.Tensor([*range( input.shape[0] + (0 if name == "resp" else 1) )]) for name, input in batch_input if name != "task" ] )

				delta = ids[batch_index].shape[0] - batch.shape[0]
				if delta > 0:
					batch = torch.cat( [ batch, torch.Tensor([1] * delta) ] )

				x_list.append( batch )

			ids = torch.stack( x_list )

		return ids.to(device=mask.device, dtype=torch.int32)

	def calc_loss(
		self,
		inputs: list,
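A standalone sketch of what the new method computes, run on a fabricated batch. The `(name, tensor)` input layout and the special-casing of "resp" / "task" mirror the diff above; the helper name and the toy inputs are made up:

import torch

def toy_position_ids(inputs, mask, unified=False):
	# llama-style: count only non-padded tokens, padding gets a dummy position of 1
	ids = mask.long().cumsum(-1) - 1
	ids.masked_fill_(mask == 0, 1)

	if not unified:
		per_batch = []
		for batch_index, batch_input in enumerate(inputs):
			# every non-"resp" section reserves one extra position (e.g. for a separator token)
			batch = torch.cat([
				torch.Tensor([*range(t.shape[0] + (0 if name == "resp" else 1))])
				for name, t in batch_input if name != "task"
			])
			# right-pad with dummy positions up to the mask length
			delta = ids[batch_index].shape[0] - batch.shape[0]
			if delta > 0:
				batch = torch.cat([batch, torch.Tensor([1] * delta)])
			per_batch.append(batch)
		ids = torch.stack(per_batch)

	return ids.to(device=mask.device, dtype=torch.int32)

# one batch entry: a 3-token "text" section and a 4-token "resp" section
inputs = [[("text", torch.zeros(3)), ("resp", torch.zeros(4))]]
mask = torch.ones(1, 8)  # 8 = 3 + 1 (separator slot) + 4, no padding

print(toy_position_ids(inputs, mask, unified=True))   # [[0, 1, 2, 3, 4, 5, 6, 7]]
print(toy_position_ids(inputs, mask, unified=False))  # [[0, 1, 2, 3, 0, 1, 2, 3]]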
@@ -1097,6 +1127,7 @@ class Base(nn.Module):
		device = x.device
		batch_size = len(x_list)


		# pure AR
		if quant_levels is None:
			quant_levels = [ 0 for _ in range(batch_size) ]
@@ -1115,11 +1146,14 @@ class Base(nn.Module):
		padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
		m = torch.cat([m, padding], dim=1)

		# needs to be done here as we still have our raw inputs
		position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None

		x, state, aux_loss = self._forward(
			inputs=x,
			mask=m,
			state=state,
			position_ids=position_ids,
		)

		if self.classifiers is not None:
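One detail worth noting in the wiring above: the mask handed to `inputs_to_position_ids` is `m` after it has been right-padded with zeros, so the padded slots pick up the dummy position 1 from the llama-style trick instead of growing the index. A tiny demo of just that trick (the mask values here are invented):

import torch

mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # 4 real tokens, 2 padding
ids = mask.long().cumsum(-1) - 1
ids.masked_fill_(mask == 0, 1)
print(ids)  # tensor([[0, 1, 2, 3, 1, 1]])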