This commit is contained in:
mrq 2024-11-10 12:48:41 -06:00
parent 9cb0b6901b
commit 9def34cd66
3 changed files with 13 additions and 5 deletions

View File

@@ -508,6 +508,7 @@ def get_task_symmap():
"<eoe>": 7,
"<stt>": 8,
"<len>": 0, # fake
"<nse>": 6, # fake
"<cse>": 6, # fake
}

View File

@@ -149,6 +149,9 @@ class AR_NAR(Base):
else:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
if task == "len":
quant_levels[i] = 0
# apply CFG (should probably only apply to NAR quant level 0)
if task not in text_task + ["len"]:
drop_text = False
@@ -277,6 +280,10 @@ class AR_NAR(Base):
cfg_strength = 1.0
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9
start_temperature = temperature
start_noise = 0.0
end_noise = 1.0
# if we're denoising from an existing sequence
if denoise_start > 0.0 and resps_list is not None:
@@ -292,13 +299,12 @@ class AR_NAR(Base):
quant_levels = [ level for _ in range(batch_size) ]
prev_list = [ input_ids ]
start_temperature = temperature
start_noise = 0.0
end_noise = 1.0
null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
null_prom = None
max_steps = math.floor(max_steps * (end_noise - start_noise))
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
# anneal temperature
temperature = start_temperature * (steps_until_x0 / max_steps)
@@ -397,7 +403,7 @@ class AR_NAR(Base):
# update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
if cfg.experimental:
if cfg.experimental and max_steps > 0:
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
return input_ids
@@ -1056,7 +1062,7 @@ def example_usage():
if task == "stt":
prom = [ task ]
else:
task = "tts"
task = "tts" if random.random() > 0.1 else "len"
texts.append( text )
proms.append( prom )

View File

@@ -140,6 +140,7 @@ def run_eval(engines, eval_name, dl, args=None):
proms_list=batch["proms"],
lang_list=batch["lang"],
task_list=batch["task"],
training=False,
)
if engine.hyper_config.experimental.hf: