fixes
This commit is contained in:
parent
9def34cd66
commit
48490757da
|
@ -282,7 +282,7 @@ def main():
|
|||
dtype=args.dtype,
|
||||
amp=args.amp,
|
||||
|
||||
verbose=False,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
if not similarities:
|
||||
|
|
|
@ -188,7 +188,7 @@ def load_engines(training=True, **model_kwargs):
|
|||
|
||||
# resize modules if I'm doing experiments and can't be assed to manually trim things
|
||||
if cfg.trainer.resize_modules:
|
||||
uses_stop_token = 1 if model.causal_size > 0 else 0
|
||||
uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) > 0 else 0
|
||||
keys = [
|
||||
("text_emb.weight", model.config.text_tokens ),
|
||||
("tasks_emb.weight", model.config.tasks ),
|
||||
|
|
|
@ -273,6 +273,7 @@ class AR_NAR(Base):
|
|||
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
|
||||
|
||||
_super = super()
|
||||
# to-do: allow for batch processing (it should probably work batched anyways)
|
||||
def demask_sampling( batch_index, seq_len ):
|
||||
# overrides
|
||||
max_steps = 10
|
||||
|
@ -321,17 +322,17 @@ class AR_NAR(Base):
|
|||
# setup inputs
|
||||
|
||||
inputs = _super.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
text_list=[ text_list[batch_index] ] if text_list else None,
|
||||
proms_list=[ proms_list[batch_index] ] if proms_list else None,
|
||||
resps_list=[ input_ids ],
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
lang_list=[ lang_list[batch_index] ] if lang_list else None,
|
||||
tone_list=[ tone_list[batch_index] ] if tone_list else None,
|
||||
time_list=[ timestep ],
|
||||
quant_levels=quant_levels,
|
||||
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
|
||||
)
|
||||
output = _super.forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
|
||||
#layer_skip_variables=sampling_layer_skip_variables,
|
||||
)
|
||||
|
||||
|
@ -342,24 +343,24 @@ class AR_NAR(Base):
|
|||
text_list=[ null_text ],
|
||||
proms_list=[ null_prom ],
|
||||
resps_list=[ input_ids ],
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
lang_list=[ lang_list[batch_index] ] if lang_list else None,
|
||||
tone_list=[ tone_list[batch_index] ] if tone_list else None,
|
||||
time_list=[ timestep ],
|
||||
quant_levels=quant_levels,
|
||||
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
|
||||
)
|
||||
null_output = _super.forward(
|
||||
inputs=null_inputs,
|
||||
quant_levels=quant_levels,
|
||||
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
|
||||
#layer_skip_variables=sampling_layer_skip_variables,
|
||||
)
|
||||
for logit, null_logits in zip(output.logits, null_output.logits):
|
||||
logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength
|
||||
for logit, null_logit in zip(output.logits, null_output.logits):
|
||||
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
|
||||
|
||||
# sample with sampler settings
|
||||
filtered_sampled = _super.sample(
|
||||
logits=logits,
|
||||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
|
||||
|
||||
temperature=temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
|
@ -375,7 +376,7 @@ class AR_NAR(Base):
|
|||
unfiltered_sampled = _super.sample(
|
||||
logits=logits,
|
||||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
|
||||
temperature=0.0,
|
||||
)
|
||||
# update previous list of tokens
|
||||
|
@ -732,8 +733,10 @@ class AR_NAR(Base):
|
|||
filename = "metrics"
|
||||
if sampling_entropix:
|
||||
filename += f'[entropix]'
|
||||
"""
|
||||
if sampling_layer_skip_exit_layer >= 0:
|
||||
filename += f'[{sampling_layer_skip_exit_layer+1}]'
|
||||
"""
|
||||
|
||||
plot_sample_metrics( metrics, filename=f'{filename}.png' )
|
||||
|
||||
|
|
|
@ -148,13 +148,14 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
elif "len" in engine.hyper_config.capabilities:
|
||||
kwargs = base_kwargs | cfg.evaluation.ar_kwargs
|
||||
max_steps = kwargs.pop("max_steps", 500)
|
||||
len_list = engine( max_steps=5, **kwargs )
|
||||
len_list = [ min( l, max_steps ) for l in len_list ]
|
||||
|
||||
if True:
|
||||
len_list = [ resp.shape[0] for resp in batch["resps"] ]
|
||||
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
|
||||
kwargs["denoise_start"] = 0.5
|
||||
else:
|
||||
len_list = engine( max_steps=5, **kwargs )
|
||||
len_list = [ min( l, max_steps ) for l in len_list ]
|
||||
|
||||
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
|
||||
resps_list = engine( **kwargs, len_list=len_list )
|
||||
|
|
Loading…
Reference in New Issue
Block a user