lol
commit 9def34cd66
parent 9cb0b6901b
@@ -508,6 +508,7 @@ def get_task_symmap():
         "<eoe>": 7,
         "<stt>": 8,
 
+        "<len>": 0, # fake
         "<nse>": 6, # fake
         "<cse>": 6, # fake
     }
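Entries marked "# fake" reuse an already-allocated id rather than introducing a new one. As a minimal, illustrative sketch of how a symmap like this is consumed (the helper name and the assumption that the aliasing only exists to avoid a new task/embedding slot are mine, not this commit's):

# Illustrative only: a lookup over the entries shown above.
# Assumption: "fake" entries alias an existing id, so no new task slot is needed.
TASK_SYMMAP = {
    "<eoe>": 7,
    "<stt>": 8,
    "<len>": 0,  # fake
    "<nse>": 6,  # fake
    "<cse>": 6,  # fake
}

def task_token_id(task: str) -> int:
    # e.g. task_token_id("len") -> 0
    return TASK_SYMMAP[f"<{task}>"]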
@@ -149,6 +149,9 @@ class AR_NAR(Base):
             else:
                 resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
+            if task == "len":
+                quant_levels[i] = 0
+
             # apply CFG (should probably only apply to NAR quant level 0)
             if task not in text_task + ["len"]:
                 drop_text = False
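The new branch pins the "len" task to quant level 0, and the guard below it applies the CFG conditioning dropout only outside the text tasks and "len". A minimal sketch of that kind of dropout decision, assuming text_task contains "stt" and that the dropout probability (cfg_text_dropout_p here) is a hyperparameter; neither name is taken from this commit:

import random

def should_drop_text(task: str, text_task=("stt",), cfg_text_dropout_p: float = 0.1) -> bool:
    # Never drop the text conditioning for tasks that depend on it
    # (the text tasks and "len"); otherwise drop it at random so the
    # model also learns an unconditional mode usable for CFG.
    if task in tuple(text_task) + ("len",):
        return False
    return random.random() < cfg_text_dropout_p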
@@ -277,6 +280,10 @@ class AR_NAR(Base):
         cfg_strength = 1.0
         sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
         sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9
 
+        start_temperature = temperature
+        start_noise = 0.0
+        end_noise = 1.0
+
         # if we're denoising from an existing sequence
         if denoise_start > 0.0 and resps_list is not None:
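The comment above motivates sampling_top_p = 0.9 for demasking samplers. For reference, a generic top-p (nucleus) filter; this is a standard sketch, not the repository's own sampler implementation:

import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # Mask out logits that fall outside the smallest set of tokens whose
    # cumulative probability exceeds top_p, always keeping the top token.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    to_remove = cumulative_probs > top_p
    to_remove[..., 1:] = to_remove[..., :-1].clone()
    to_remove[..., 0] = False
    # map the removal mask back to the original vocabulary order
    remove_mask = torch.zeros_like(to_remove).scatter_(-1, sorted_indices, to_remove)
    return logits.masked_fill(remove_mask, float("-inf"))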
@@ -292,13 +299,12 @@ class AR_NAR(Base):
         quant_levels = [ level for _ in range(batch_size) ]
         prev_list = [ input_ids ]
 
-        start_temperature = temperature
-        start_noise = 0.0
-        end_noise = 1.0
 
         null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
         null_prom = None
 
+        max_steps = math.floor(max_steps * (end_noise - start_noise))
+
         for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
             # anneal temperature
             temperature = start_temperature * (steps_until_x0 / max_steps)
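A self-contained sketch of the schedule driven by the loop above: the step budget is scaled to the remaining noise range (which matters when resuming from denoise_start), timesteps come from torch.linspace, and the temperature anneals linearly toward zero as the sequence approaches x0. The model call and scoring are omitted; the default arguments are placeholders:

import math
import torch

def demask_schedule(max_steps: int = 30, start_noise: float = 0.0, end_noise: float = 1.0,
                    start_temperature: float = 1.0):
    # fewer steps when only part of the noise range remains to be denoised
    max_steps = math.floor(max_steps * (end_noise - start_noise))
    for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps),
                                        reversed(range(max_steps))):
        # anneal temperature toward 0 as x0 approaches
        temperature = start_temperature * (steps_until_x0 / max_steps)
        yield float(timestep), steps_until_x0, temperature

On the final step steps_until_x0 is 0, so the annealed temperature reaches 0.0.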
@@ -397,7 +403,7 @@ class AR_NAR(Base):
             # update scores (conjugated to put the worst scores at the top)
             scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
 
-            if cfg.experimental:
+            if cfg.experimental and max_steps > 0:
                 print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
 
         return input_ids
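Because the scores are conjugated (1.0 - confidence), the worst tokens end up with the largest values, which is what a top-k selection for re-masking keys on. A minimal sketch of that use, with a hypothetical mask_token_id default and helper name; not this repository's code:

import torch

def remask_worst(input_ids: torch.Tensor, scores: torch.Tensor, remask_n: int,
                 mask_token_id: int = 1024) -> torch.Tensor:
    # scores: (seq_len,), already conjugated, i.e. 1.0 - confidence
    if remask_n <= 0:
        return input_ids
    worst = torch.topk(scores, k=min(remask_n, scores.numel())).indices
    out = input_ids.clone()
    out[worst] = mask_token_id
    return out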
@@ -1056,7 +1062,7 @@ def example_usage():
         if task == "stt":
             prom = [ task ]
         else:
-            task = "tts"
+            task = "tts" if random.random() > 0.1 else "len"
 
         texts.append( text )
         proms.append( prom )
@@ -140,6 +140,7 @@ def run_eval(engines, eval_name, dl, args=None):
             proms_list=batch["proms"],
             lang_list=batch["lang"],
             task_list=batch["task"],
+            training=False,
         )
 
         if engine.hyper_config.experimental.hf: