This commit is contained in:
mrq 2024-11-10 12:48:41 -06:00
parent 9cb0b6901b
commit 9def34cd66
3 changed files with 13 additions and 5 deletions

View File

@@ -508,6 +508,7 @@ def get_task_symmap():
"<eoe>": 7, "<eoe>": 7,
"<stt>": 8, "<stt>": 8,
"<len>": 0, # fake
"<nse>": 6, # fake "<nse>": 6, # fake
"<cse>": 6, # fake "<cse>": 6, # fake
} }

View File

@@ -149,6 +149,9 @@ class AR_NAR(Base):
else: else:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ]) resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
if task == "len":
quant_levels[i] = 0
# apply CFG (should probably only apply to NAR quant level 0) # apply CFG (should probably only apply to NAR quant level 0)
if task not in text_task + ["len"]: if task not in text_task + ["len"]:
drop_text = False drop_text = False
@@ -277,6 +280,10 @@ class AR_NAR(Base):
cfg_strength = 1.0 cfg_strength = 1.0
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied...... sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9 sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9
start_temperature = temperature
start_noise = 0.0
end_noise = 1.0
# if we're denoising from an existing sequence # if we're denoising from an existing sequence
if denoise_start > 0.0 and resps_list is not None: if denoise_start > 0.0 and resps_list is not None:
@@ -292,13 +299,12 @@ class AR_NAR(Base):
quant_levels = [ level for _ in range(batch_size) ] quant_levels = [ level for _ in range(batch_size) ]
prev_list = [ input_ids ] prev_list = [ input_ids ]
start_temperature = temperature
start_noise = 0.0
end_noise = 1.0
null_text = torch.tensor([1, 2], device=device, dtype=torch.int16) null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
null_prom = None null_prom = None
max_steps = math.floor(max_steps * (end_noise - start_noise))
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))): for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
# anneal temperature # anneal temperature
temperature = start_temperature * (steps_until_x0 / max_steps) temperature = start_temperature * (steps_until_x0 / max_steps)
@@ -397,7 +403,7 @@ class AR_NAR(Base):
# update scores (conjugated to put the worst scores at the top) # update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device) scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
if cfg.experimental: if cfg.experimental and max_steps > 0:
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores ) print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
return input_ids return input_ids
@@ -1056,7 +1062,7 @@ def example_usage():
if task == "stt": if task == "stt":
prom = [ task ] prom = [ task ]
else: else:
task = "tts" task = "tts" if random.random() > 0.1 else "len"
texts.append( text ) texts.append( text )
proms.append( prom ) proms.append( prom )

View File

@@ -140,6 +140,7 @@ def run_eval(engines, eval_name, dl, args=None):
proms_list=batch["proms"], proms_list=batch["proms"],
lang_list=batch["lang"], lang_list=batch["lang"],
task_list=batch["task"], task_list=batch["task"],
training=False,
) )
if engine.hyper_config.experimental.hf: if engine.hyper_config.experimental.hf: