move NAR-len rvq level 0 to separate embedding
parent 29e45be0b4
commit 269648605e
@@ -260,6 +260,7 @@ class ModelExperimentalSettings:

    masking_train_p: float = 0.0 # odds of training with masking
    masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
    masking_separate_embeddings: bool = False

    # classifier-free guidance shit
    cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
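The new masking_separate_embeddings flag is the only configuration change in this hunk. A minimal sketch of how it sits alongside the existing masking options (field names taken from the hunk above; the usage line at the end is hypothetical):

from dataclasses import dataclass, field

@dataclass
class ModelExperimentalSettings:
    masking_train_p: float = 0.0                # odds of training with masking
    masking_train_rvq_levels: list = field(default_factory=lambda: [0, 0])  # which RVQ levels get mask training
    masking_separate_embeddings: bool = False   # give masked NAR-len level 0 its own embedding/classifier
    cfg_cond_dropout_p: float = 0.0             # classifier-free guidance dropout

# hypothetical usage: opt a model into the separate NAR-len embedding
settings = ModelExperimentalSettings(masking_separate_embeddings=True)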
@@ -204,6 +204,17 @@ def load_engines(training=True, **model_kwargs):
                continue
            state[k] = ml.resize_weight( state[k], tokens )

        """
        if model.config.experimental.masking_separate_embeddings and "resps_emb.embeddings.8.weight" not in state:
            state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
            state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()

            del state['classifiers.proj.8.weight']
            del state['classifiers.proj.8.bias']

            state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
        """

        model.load_state_dict(state, strict=cfg.trainer.strict_loading)

        # load lora weights if exists
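The block added above is left commented out behind a docstring, but it documents how an older checkpoint would be remapped once the separate embedding exists. A standalone sketch of that migration, using the key names from the hunk and made-up tensor shapes:

import torch

def migrate_state_dict(state: dict) -> dict:
    # skip checkpoints that already carry the dedicated NAR-len embedding
    if "resps_emb.embeddings.8.weight" in state:
        return state
    # move the old combined classifier from slot 8 to slot 9
    state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
    state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
    del state['classifiers.proj.8.weight']
    del state['classifiers.proj.8.bias']
    # seed the new masked-level-0 embedding from the RVQ level-0 embedding
    state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
    return state

# toy checkpoint with made-up shapes, just to exercise the remapping
fake_state = {
    'classifiers.proj.8.weight': torch.zeros(1025, 1024),
    'classifiers.proj.8.bias': torch.zeros(1025),
    'resps_emb.embeddings.0.weight': torch.zeros(1025, 1024),
}
fake_state = migrate_state_dict(fake_state)
assert 'classifiers.proj.9.weight' in fake_state and 'classifiers.proj.8.weight' not in fake_state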
@@ -108,7 +108,7 @@ class AR_NAR(Base):
                #p = math.acos(r) / (math.pi * 0.5)
                #timesteps[i] = 1.0 - clamp(p, 0.0, 1.0)
                timesteps[i] = random.random()

        # trim resps to only contain all levels below the target level
        resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
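For context on the trimming line above, a toy illustration with made-up shapes and task names (text_task is assumed to hold the text-output tasks such as "stt"): each audio sample keeps only the codebook levels up to and including its target quantizer level.

import torch

text_task = ["stt"]                      # assumed contents
quant_levels = [0, 3]                    # per-sample target RVQ level
task_list = ["tts", "tts"]
resps_list = [torch.zeros(75, 8, dtype=torch.long) for _ in range(2)]  # (frames, levels)

resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
print([tuple(r.shape) for r in resps_list])  # [(75, 1), (75, 4)]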
@@ -896,7 +896,7 @@ def example_usage():

    text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
    batch_size = cfg.hyperparameters.batch_size
    cfg.model.experimental.masking_train_p = 0.5
    cfg.model.experimental.masking_train_p = 1.0

    text_list = [ text ] * batch_size
    proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
@@ -476,6 +476,8 @@ class Base(nn.Module):
        layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
        layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1

        masking_separate_embeddings = self.config.experimental.masking_separate_embeddings if self.config is not None else False

        n_tasks = self.config.tasks if self.config is not None else 8
        n_langs = self.config.langs if self.config is not None else 2
        n_tones = self.config.tones if self.config is not None else 1
@@ -484,7 +486,12 @@ class Base(nn.Module):
        if "nar" not in self.capabilities:
            n_resp_tokens = n_audio_tokens + 1
            l_tokens = [n_resp_tokens] * self.n_resp_levels
        # AR+NAR model / NAR-len model
        # NAR-len model
        elif "len" in self.capabilities and masking_separate_embeddings:
            # +1 to include the stop or mask token
            n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
            l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
        # AR+NAR model
        else:
            # +1 to include the stop or mask token
            n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
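The new elif branch above is the core of the commit: with masking_separate_embeddings set, the NAR-len model gets one extra entry at the end of the per-level vocabulary list for the masked level-0 pass. A sketch with assumed values (n_audio_tokens=1024, n_resp_levels=8, causal_size=1; none of these come from the diff):

n_audio_tokens, n_resp_levels, causal_size = 1024, 8, 1

# the +1 keeps room for the stop/mask token on the first entry and on the extra last entry
n_resp_tokens = n_audio_tokens + (1 if causal_size > 0 else 0)
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1) + [n_resp_tokens]
print(l_tokens)  # [1025, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1025]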
@@ -495,6 +502,7 @@ class Base(nn.Module):
        self.layerskip = layerskip
        self.special_tasks = [ "len", "stt" ]
        self.inject_timestep_embedding = False # results in bad output
        self.masking_separate_embeddings = masking_separate_embeddings

        self.text_emb = Embedding(n_text_tokens, d_model)
        self.langs_emb = None
@@ -1182,7 +1190,7 @@ class Base(nn.Module):
                embedding = self.resps_emb(
                    # if masked use masked token, else original token
                    torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
                    offset = 0,
                    offset = -1 if self.masking_separate_embeddings else 0, # pick last
                    quant_level = 0,
                )
            # cheat-y way to handle performing STT across all levels
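Why offset = -1 "picks last": assuming the response embedding keeps one table per entry of l_tokens and adds offset to the quantizer level when indexing (the toy class below is not the project's actual embedding module), quant_level 0 with offset -1 resolves to the final table, i.e. the dedicated masking embedding appended above.

import torch
from torch import nn

class ToyRespEmbedding(nn.Module):
    def __init__(self, l_tokens, d_model):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(n, d_model) for n in l_tokens])

    def forward(self, tokens, quant_level=0, offset=0):
        # negative offsets wrap around to the tables appended at the end
        return self.embeddings[quant_level + offset](tokens)

emb = ToyRespEmbedding([1025] + [1024] * 7 + [1025], d_model=16)
tokens = torch.randint(0, 1024, (4,))
assert emb(tokens, quant_level=0, offset=-1).shape == (4, 16)  # routed through the last table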
@@ -1325,10 +1333,9 @@ class Base(nn.Module):
        device = logits[0].device
        batch_size = len(logits)
        summed_embeddings_task = [ "stt" ]
        #classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]

        tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
        is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]

        # handles tasks where the prompt has task tokens injected in the middle
        def prompt_input_to_token( input, quant_level ):
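A worked example of the per-sample classifier selection introduced above, with made-up batch contents. The convention assumed here is that -1 addresses the special-task head and -2 the new NAR-len (masked level 0) head, both counted from the end of the classifier list:

special_tasks = ["len", "stt"]
tasks = ["tts", "stt", "tts"]
quant_levels = [0, 0, 3]
is_nar_len = [True, False, False]   # sample 0 is a masked NAR-len pass

classifier_quant_levels = [
    -1 if tasks[i] in special_tasks else (-2 if is_nar_len[i] else l)
    for i, l in enumerate(quant_levels)
]
print(classifier_quant_levels)  # [-2, -1, 3]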
@@ -1623,15 +1630,15 @@ class Base(nn.Module):
        position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None

        tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]

        is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]

        if self.inject_timestep_embedding:
            timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
            timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
        else:
            timesteps = []

        classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]

        output = self._forward(
            inputs=x,
            mask=mask,