lol
This commit is contained in:
parent
9cb0b6901b
commit
9def34cd66
|
@ -508,6 +508,7 @@ def get_task_symmap():
|
||||||
"<eoe>": 7,
|
"<eoe>": 7,
|
||||||
"<stt>": 8,
|
"<stt>": 8,
|
||||||
|
|
||||||
|
"<len>": 0, # fake
|
||||||
"<nse>": 6, # fake
|
"<nse>": 6, # fake
|
||||||
"<cse>": 6, # fake
|
"<cse>": 6, # fake
|
||||||
}
|
}
|
||||||
|
|
|
@ -149,6 +149,9 @@ class AR_NAR(Base):
|
||||||
else:
|
else:
|
||||||
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
|
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
|
||||||
|
|
||||||
|
if task == "len":
|
||||||
|
quant_levels[i] = 0
|
||||||
|
|
||||||
# apply CFG (should probably only apply to NAR quant level 0)
|
# apply CFG (should probably only apply to NAR quant level 0)
|
||||||
if task not in text_task + ["len"]:
|
if task not in text_task + ["len"]:
|
||||||
drop_text = False
|
drop_text = False
|
||||||
|
@ -277,6 +280,10 @@ class AR_NAR(Base):
|
||||||
cfg_strength = 1.0
|
cfg_strength = 1.0
|
||||||
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
|
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
|
||||||
sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9
|
sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9
|
||||||
|
|
||||||
|
start_temperature = temperature
|
||||||
|
start_noise = 0.0
|
||||||
|
end_noise = 1.0
|
||||||
|
|
||||||
# if we're denoising from an existing sequence
|
# if we're denoising from an existing sequence
|
||||||
if denoise_start > 0.0 and resps_list is not None:
|
if denoise_start > 0.0 and resps_list is not None:
|
||||||
|
@ -292,13 +299,12 @@ class AR_NAR(Base):
|
||||||
quant_levels = [ level for _ in range(batch_size) ]
|
quant_levels = [ level for _ in range(batch_size) ]
|
||||||
prev_list = [ input_ids ]
|
prev_list = [ input_ids ]
|
||||||
|
|
||||||
start_temperature = temperature
|
|
||||||
start_noise = 0.0
|
|
||||||
end_noise = 1.0
|
|
||||||
|
|
||||||
null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
|
null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
|
||||||
null_prom = None
|
null_prom = None
|
||||||
|
|
||||||
|
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
||||||
|
|
||||||
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
|
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
|
||||||
# anneal temperature
|
# anneal temperature
|
||||||
temperature = start_temperature * (steps_until_x0 / max_steps)
|
temperature = start_temperature * (steps_until_x0 / max_steps)
|
||||||
|
@ -397,7 +403,7 @@ class AR_NAR(Base):
|
||||||
# update scores (conjugated to put the worst scores at the top)
|
# update scores (conjugated to put the worst scores at the top)
|
||||||
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
|
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
|
||||||
|
|
||||||
if cfg.experimental:
|
if cfg.experimental and max_steps > 0:
|
||||||
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
|
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
|
||||||
|
|
||||||
return input_ids
|
return input_ids
|
||||||
|
@ -1056,7 +1062,7 @@ def example_usage():
|
||||||
if task == "stt":
|
if task == "stt":
|
||||||
prom = [ task ]
|
prom = [ task ]
|
||||||
else:
|
else:
|
||||||
task = "tts"
|
task = "tts" if random.random() > 0.1 else "len"
|
||||||
|
|
||||||
texts.append( text )
|
texts.append( text )
|
||||||
proms.append( prom )
|
proms.append( prom )
|
||||||
|
|
|
@ -140,6 +140,7 @@ def run_eval(engines, eval_name, dl, args=None):
|
||||||
proms_list=batch["proms"],
|
proms_list=batch["proms"],
|
||||||
lang_list=batch["lang"],
|
lang_list=batch["lang"],
|
||||||
task_list=batch["task"],
|
task_list=batch["task"],
|
||||||
|
training=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if engine.hyper_config.experimental.hf:
|
if engine.hyper_config.experimental.hf:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user