mrq 2024-11-10 20:37:50 -06:00
parent 9def34cd66
commit 48490757da
4 changed files with 22 additions and 18 deletions

View File

@@ -282,7 +282,7 @@ def main():
dtype=args.dtype,
amp=args.amp,
verbose=False,
verbose=True,
)
if not similarities:

View File

@@ -188,7 +188,7 @@ def load_engines(training=True, **model_kwargs):
# resize modules if I'm doing experiments and can't be assed to manually trim things
if cfg.trainer.resize_modules:
uses_stop_token = 1 if model.causal_size > 0 else 0
uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) > 0 else 0
keys = [
("text_emb.weight", model.config.text_tokens ),
("tasks_emb.weight", model.config.tasks ),

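Note: this hunk changes when the resize accounts for a stop token (any model with an "ar" or "len" capability, rather than only ones with a causal_size), and the keys list pairs each weight name with its target token count. As a rough illustration only, not the repo's actual helper, resizing an embedding weight to a target row count generally just copies whatever rows overlap:

import torch

def resize_weight(weight: torch.Tensor, tokens: int) -> torch.Tensor:
	# illustrative sketch: grow or shrink the token dimension, copying the overlapping rows
	resized = torch.zeros(tokens, *weight.shape[1:], dtype=weight.dtype, device=weight.device)
	n = min(tokens, weight.shape[0])
	resized[:n] = weight[:n]
	return resized

# hypothetical usage against the keys listed above:
# state_dict["text_emb.weight"] = resize_weight(state_dict["text_emb.weight"], model.config.text_tokens)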
View File

@@ -273,6 +273,7 @@ class AR_NAR(Base):
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
_super = super()
# to-do: allow for batch processing (it should probably work batched anyways)
def demask_sampling( batch_index, seq_len ):
# overrides
max_steps = 10
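Note: the one-line sampler shown in context above is the Gumbel-max trick: adding -log(-log(U)) noise to temperature-scaled logits and taking the argmax draws from the same categorical distribution as softmax sampling. A standalone equivalent (names illustrative, not from the repo):

import torch

def gumbel_argmax(logits: torch.Tensor, temperature: float = 1.0, dim: int = -1) -> torch.Tensor:
	# Gumbel noise: -log(-log(U)), with U ~ Uniform(0, 1)
	noise = -torch.log(-torch.log(torch.zeros_like(logits).uniform_(0, 1)))
	# argmax of (logits / T + noise) is distributed as Categorical(softmax(logits / T))
	return ((logits / max(temperature, 1e-10)) + noise).argmax(dim=dim)

# e.g. one token id per row of a [batch, vocab] tensor:
# ids = gumbel_argmax(torch.randn(4, 1024))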
@@ -321,17 +322,17 @@ class AR_NAR(Base):
# setup inputs
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
text_list=[ text_list[batch_index] ] if text_list else None,
proms_list=[ proms_list[batch_index] ] if proms_list else None,
resps_list=[ input_ids ],
lang_list=lang_list,
tone_list=tone_list,
lang_list=[ lang_list[batch_index] ] if lang_list else None,
tone_list=[ tone_list[batch_index] ] if tone_list else None,
time_list=[ timestep ],
quant_levels=quant_levels,
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
)
output = _super.forward(
inputs=inputs,
quant_levels=quant_levels,
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
#layer_skip_variables=sampling_layer_skip_variables,
)
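Note: every list-typed argument here gets narrowed to a one-element list for the current batch entry, since demask_sampling (per the to-do above) still operates on one sequence at a time. The repeated `[ x[batch_index] ] if x else None` idiom could be factored into a tiny helper; purely illustrative, not in the repo:

def pick_one(xs, index):
	# wrap a single batch entry back into a one-element list, keeping None for absent inputs
	return [ xs[index] ] if xs else None

# e.g. text_list=pick_one(text_list, batch_index), proms_list=pick_one(proms_list, batch_index), ...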
@@ -342,24 +343,24 @@ class AR_NAR(Base):
text_list=[ null_text ],
proms_list=[ null_prom ],
resps_list=[ input_ids ],
lang_list=lang_list,
tone_list=tone_list,
lang_list=[ lang_list[batch_index] ] if lang_list else None,
tone_list=[ tone_list[batch_index] ] if tone_list else None,
time_list=[ timestep ],
quant_levels=quant_levels,
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
)
null_output = _super.forward(
inputs=null_inputs,
quant_levels=quant_levels,
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
#layer_skip_variables=sampling_layer_skip_variables,
)
for logit, null_logits in zip(output.logits, null_output.logits):
logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength
for logit, null_logit in zip(output.logits, null_output.logits):
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
# sample with sampler settings
filtered_sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
temperature=temperature,
min_temperature=sampling_min_temperature,
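Note: the reworked logit merge above is the standard classifier-free guidance form, guided = null + (cond - null) * cfg_strength, applied only over the last seq_len positions: cfg_strength = 1.0 reproduces the conditional logits exactly, while larger values extrapolate away from the null-conditioned output. A minimal sketch (tensor names illustrative):

import torch

def apply_cfg(cond: torch.Tensor, null: torch.Tensor, cfg_strength: float) -> torch.Tensor:
	# start from the unconditional logits and move along the (conditional - unconditional) direction
	return null + (cond - null) * cfg_strength

# per the diff, only the generated span is guided:
# logit[-seq_len:] = apply_cfg(logit[-seq_len:], null_logit[-seq_len:], cfg_strength)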
@@ -375,7 +376,7 @@ class AR_NAR(Base):
unfiltered_sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
temperature=0.0,
)
# update previous list of tokens
@@ -732,8 +733,10 @@ class AR_NAR(Base):
filename = "metrics"
if sampling_entropix:
filename += f'[entropix]'
"""
if sampling_layer_skip_exit_layer >= 0:
filename += f'[{sampling_layer_skip_exit_layer+1}]'
"""
plot_sample_metrics( metrics, filename=f'{filename}.png' )

View File

@@ -148,13 +148,14 @@ def run_eval(engines, eval_name, dl, args=None):
elif "len" in engine.hyper_config.capabilities:
kwargs = base_kwargs | cfg.evaluation.ar_kwargs
max_steps = kwargs.pop("max_steps", 500)
len_list = engine( max_steps=5, **kwargs )
len_list = [ min( l, max_steps ) for l in len_list ]
if True:
len_list = [ resp.shape[0] for resp in batch["resps"] ]
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
kwargs["denoise_start"] = 0.5
else:
len_list = engine( max_steps=5, **kwargs )
len_list = [ min( l, max_steps ) for l in len_list ]
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
resps_list = engine( **kwargs, len_list=len_list )
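Note: assuming each entry of batch["resps"] is a [timesteps, quant-levels] code tensor, the new branch swaps length prediction for the ground-truth frame counts and seeds resps_list with just the first codebook level, presumably so the denoising pass can start from the halfway point set by denoise_start = 0.5. A small shape check of what those slices pull out (values illustrative):

import torch

resp = torch.randint(0, 1024, (500, 8))  # e.g. 500 frames, 8 RVQ levels

gt_len = resp.shape[0]  # ground-truth length fed into len_list
seed   = resp[:, :1]    # first quantizer level only, used as the initial resps_list entry
assert gt_len == 500 and seed.shape == (500, 1)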