fixes

parent 9def34cd66
commit 48490757da
@@ -282,7 +282,7 @@ def main():
 			dtype=args.dtype,
 			amp=args.amp,

-			verbose=False,
+			verbose=True,
 		)

 		if not similarities:
@@ -188,7 +188,7 @@ def load_engines(training=True, **model_kwargs):

 	# resize modules if I'm doing experiments and can't be assed to manually trim things
 	if cfg.trainer.resize_modules:
-		uses_stop_token = 1 if model.causal_size > 0 else 0
+		uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) > 0 else 0
 		keys = [
 			("text_emb.weight", model.config.text_tokens ),
 			("tasks_emb.weight", model.config.tasks ),
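Note: the edited condition now reserves the stop token whenever the model advertises an "ar" or "len" capability instead of keying off causal_size. For context, resizing along the listed keys amounts to padding or truncating the first dimension of each embedding weight; the helper below is only a sketch of that idea (the function name and the zero-initialization of new rows are assumptions, not the repo's actual code):

import torch

def resize_weight( weight: torch.Tensor, new_tokens: int ) -> torch.Tensor:
	# sketch: keep the rows that overlap, zero-fill anything newly added
	old_tokens, dim = weight.shape
	if new_tokens == old_tokens:
		return weight
	resized = torch.zeros( new_tokens, dim, dtype=weight.dtype, device=weight.device )
	keep = min( old_tokens, new_tokens )
	resized[:keep] = weight[:keep]
	return resized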
@@ -273,6 +273,7 @@ class AR_NAR(Base):
 			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)

 		_super = super()
+		# to-do: allow for batch processing (it should probably work batched anyways)
 		def demask_sampling( batch_index, seq_len ):
 			# overrides
 			max_steps = 10
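Note: the return line in this hunk is the Gumbel-max trick: adding -log(-log(U)) noise to temperature-scaled logits and taking the argmax draws a sample from softmax(logits / temperature). A standalone sketch of the same identity, with a clamp added only to avoid log(0) (the clamp and function name are my additions, not the repo's):

import torch

def gumbel_argmax_sample( logits: torch.Tensor, temperature: float = 1.0, dim: int = -1 ) -> torch.Tensor:
	# Gumbel noise: -log(-log(U)), U ~ Uniform(0, 1)
	uniform = torch.rand_like( logits ).clamp_( min=1e-20 )
	noise = -torch.log( -torch.log( uniform ) )
	return ( logits / max( temperature, 1e-10 ) + noise ).argmax( dim=dim )

# e.g. tokens = gumbel_argmax_sample( logits, temperature=0.95 )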
@@ -321,17 +322,17 @@ class AR_NAR(Base):
 			# setup inputs

 			inputs = _super.inputs(
-				text_list=text_list,
-				proms_list=proms_list,
+				text_list=[ text_list[batch_index] ] if text_list else None,
+				proms_list=[ proms_list[batch_index] ] if proms_list else None,
 				resps_list=[ input_ids ],
-				lang_list=lang_list,
-				tone_list=tone_list,
+				lang_list=[ lang_list[batch_index] ] if lang_list else None,
+				tone_list=[ tone_list[batch_index] ] if tone_list else None,
 				time_list=[ timestep ],
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 			)
 			output = _super.forward(
 				inputs=inputs,
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)

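Note: every change in this hunk is the same pattern: each batched keyword list is narrowed to a one-element list for the current batch_index so demask_sampling operates on a single utterance, falling back to None when the caller didn't pass that list. A hypothetical helper expressing the same intent (the commit inlines the conditional at each call site instead):

def _slice_batch( values, batch_index ):
	# one-element list for this batch entry, or None if the list wasn't given
	return [ values[batch_index] ] if values else None

# e.g. text_list=_slice_batch( text_list, batch_index ) matches the edited kwargs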
@@ -342,24 +343,24 @@ class AR_NAR(Base):
 				text_list=[ null_text ],
 				proms_list=[ null_prom ],
 				resps_list=[ input_ids ],
-				lang_list=lang_list,
-				tone_list=tone_list,
+				lang_list=[ lang_list[batch_index] ] if lang_list else None,
+				tone_list=[ tone_list[batch_index] ] if tone_list else None,
 				time_list=[ timestep ],
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 			)
 			null_output = _super.forward(
 				inputs=null_inputs,
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)
-			for logit, null_logits in zip(output.logits, null_output.logits):
-				logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength
+			for logit, null_logit in zip(output.logits, null_output.logits):
+				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength

 			# sample with sampler settings
 			filtered_sampled = _super.sample(
 				logits=logits,
 				prev_list=prev_list,
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,

 				temperature=temperature,
 				min_temperature=sampling_min_temperature,
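Note: the logit update above is the classifier-free guidance fix in this commit. The old form added the (conditional - unconditional) delta on top of the conditional logits, which double-counts the conditional term; the corrected form interpolates from the null logits, so cfg_strength = 1.0 recovers the plain conditional output and larger values push further away from the unconditional prediction. As a plain-tensor sketch (names are illustrative, not repo API):

import torch

def apply_cfg( cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float ) -> torch.Tensor:
	# 0.0 -> purely unconditional, 1.0 -> purely conditional, > 1.0 -> extrapolate away from the null prediction
	return null_logits + ( cond_logits - null_logits ) * cfg_strength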
@@ -375,7 +376,7 @@ class AR_NAR(Base):
 			unfiltered_sampled = _super.sample(
 				logits=logits,
 				prev_list=prev_list,
-				quant_levels=quant_levels,
+				quant_levels=[ quant_levels[batch_index] ] if quant_levels else None,
 				temperature=0.0,
 			)
 			# update previous list of tokens
@@ -732,8 +733,10 @@ class AR_NAR(Base):
 			filename = "metrics"
 			if sampling_entropix:
 				filename += f'[entropix]'
+			"""
 			if sampling_layer_skip_exit_layer >= 0:
 				filename += f'[{sampling_layer_skip_exit_layer+1}]'
+			"""

 			plot_sample_metrics( metrics, filename=f'{filename}.png' )

@@ -148,13 +148,14 @@ def run_eval(engines, eval_name, dl, args=None):
 		elif "len" in engine.hyper_config.capabilities:
 			kwargs = base_kwargs | cfg.evaluation.ar_kwargs
 			max_steps = kwargs.pop("max_steps", 500)
-			len_list = engine( max_steps=5, **kwargs )
-			len_list = [ min( l, max_steps ) for l in len_list ]

 			if True:
 				len_list = [ resp.shape[0] for resp in batch["resps"] ]
 				kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
 				kwargs["denoise_start"] = 0.5
+			else:
+				len_list = engine( max_steps=5, **kwargs )
+				len_list = [ min( l, max_steps ) for l in len_list ]

 			kwargs = base_kwargs | cfg.evaluation.nar_kwargs
 			resps_list = engine( **kwargs, len_list=len_list )
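Note: this hunk pins evaluation to oracle lengths for now: the hard-coded `if True:` branch takes the reference durations (plus the first codebook level as a partially-denoised prefix with denoise_start=0.5), while the length-model path, clamped to max_steps, is kept behind the new else:. A minimal, standalone illustration of lifting that hard-coded branch into a named toggle (the function and flag names are hypothetical, not cfg keys in the repo):

def pick_len_list( resp_lengths, predicted_lengths, max_steps, use_oracle_lengths=True ):
	# oracle path: trust the reference durations as-is
	if use_oracle_lengths:
		return list( resp_lengths )
	# predicted path: clamp so evaluation can't run unbounded
	return [ min( l, max_steps ) for l in predicted_lengths ]

# e.g. pick_len_list( [120, 95], [130, 90], max_steps=500, use_oracle_lengths=False ) -> [130, 90]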