move NAR-len rvq level 0 to separate embedding

This commit is contained in:
mrq 2024-11-13 11:38:58 -06:00
parent 29e45be0b4
commit 269648605e
4 changed files with 29 additions and 10 deletions

View File

@ -260,6 +260,7 @@ class ModelExperimentalSettings:
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
masking_separate_embeddings: bool = False
# classifier-free guidance shit
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training

View File

@ -204,6 +204,17 @@ def load_engines(training=True, **model_kwargs):
continue
state[k] = ml.resize_weight( state[k], tokens )
"""
if model.config.experimental.masking_separate_embeddings and "resps_emb.embeddings.8.weight" not in state:
state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
del state['classifiers.proj.8.weight']
del state['classifiers.proj.8.bias']
state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
"""
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# load lora weights if exists

View File

@ -108,7 +108,7 @@ class AR_NAR(Base):
#p = math.acos(r) / (math.pi * 0.5)
#timesteps[i] = 1.0 - clamp(p, 0.0, 1.0)
timesteps[i] = random.random()
# trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
@ -896,7 +896,7 @@ def example_usage():
text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
batch_size = cfg.hyperparameters.batch_size
cfg.model.experimental.masking_train_p = 0.5
cfg.model.experimental.masking_train_p = 1.0
text_list = [ text ] * batch_size
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size

View File

@ -476,6 +476,8 @@ class Base(nn.Module):
layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1
masking_separate_embeddings = self.config.experimental.masking_separate_embeddings if self.config is not None else False
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
n_tones = self.config.tones if self.config is not None else 1
@ -484,7 +486,12 @@ class Base(nn.Module):
if "nar" not in self.capabilities:
n_resp_tokens = n_audio_tokens + 1
l_tokens = [n_resp_tokens] * self.n_resp_levels
# AR+NAR model / NAR-len model
# NAR-len model
elif "len" in self.capabilities and masking_separate_embeddings:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
# AR+NAR model
else:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
@ -495,6 +502,7 @@ class Base(nn.Module):
self.layerskip = layerskip
self.special_tasks = [ "len", "stt" ]
self.inject_timestep_embedding = False # results in bad output
self.masking_separate_embeddings = masking_separate_embeddings
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@ -1182,7 +1190,7 @@ class Base(nn.Module):
embedding = self.resps_emb(
# if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
offset = 0,
offset = -1 if self.masking_separate_embeddings else 0, # pick last
quant_level = 0,
)
# cheat-y way to handle performing STT across all levels
@ -1325,10 +1333,9 @@ class Base(nn.Module):
device = logits[0].device
batch_size = len(logits)
summed_embeddings_task = [ "stt" ]
#classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
@ -1623,15 +1630,15 @@ class Base(nn.Module):
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
if self.inject_timestep_embedding:
timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
else:
timesteps = []
classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
output = self._forward(
inputs=x,
mask=mask,