threw in CFG sampling for normal model as well to experiment with
This commit is contained in:
parent
2f56696506
commit
b1f4db39c8
|
@ -323,7 +323,7 @@ class AR_NAR(Base):
|
||||||
# keep unmasked tokens
|
# keep unmasked tokens
|
||||||
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
||||||
# update scores (conjugated to put the worst scores at the top)
|
# update scores (conjugated to put the worst scores at the top)
|
||||||
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ]
|
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
|
||||||
|
|
||||||
if cfg.experimental and max_steps > 0:
|
if cfg.experimental and max_steps > 0:
|
||||||
print( timestep, steps_until_x0, noise_p, resps_list, scores )
|
print( timestep, steps_until_x0, noise_p, resps_list, scores )
|
||||||
|
@ -356,10 +356,12 @@ class AR_NAR(Base):
|
||||||
batch_size = len(resps_list)
|
batch_size = len(resps_list)
|
||||||
|
|
||||||
|
|
||||||
max_levels = sampling_kwargs.get("max_levels", 0)
|
|
||||||
# convert NAR specific args
|
# convert NAR specific args
|
||||||
sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
|
sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
|
||||||
|
|
||||||
|
max_levels = sampling_kwargs.get("max_levels", 0)
|
||||||
|
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
|
||||||
|
|
||||||
if max_levels == 0:
|
if max_levels == 0:
|
||||||
max_levels = self.n_max_levels - 1
|
max_levels = self.n_max_levels - 1
|
||||||
|
|
||||||
|
@ -395,6 +397,9 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
prev_list = resps_list
|
prev_list = resps_list
|
||||||
|
|
||||||
|
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
|
||||||
|
null_prom = [ None for _ in range(batch_size) ]
|
||||||
|
|
||||||
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
|
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
|
||||||
level = prev_list[0].shape[-1]
|
level = prev_list[0].shape[-1]
|
||||||
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
||||||
|
@ -421,6 +426,23 @@ class AR_NAR(Base):
|
||||||
)
|
)
|
||||||
logits, state = output.logits, output.state
|
logits, state = output.logits, output.state
|
||||||
|
|
||||||
|
if cfg_strength > 0:
|
||||||
|
null_inputs = super().inputs(
|
||||||
|
text_list=null_text,
|
||||||
|
proms_list=null_prom,
|
||||||
|
resps_list=resps_list,
|
||||||
|
lang_list=lang_list,
|
||||||
|
tone_list=tone_list,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
)
|
||||||
|
null_output = super().forward(
|
||||||
|
inputs=null_inputs,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
#layer_skip_variables=sampling_layer_skip_variables,
|
||||||
|
)
|
||||||
|
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
|
||||||
|
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
|
||||||
|
|
||||||
sampled = super().sample(
|
sampled = super().sample(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
prev_list=prev_list,
|
prev_list=prev_list,
|
||||||
|
@ -465,6 +487,7 @@ class AR_NAR(Base):
|
||||||
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
|
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
|
||||||
|
|
||||||
temperature = sampling_kwargs.get("temperature", 1.0)
|
temperature = sampling_kwargs.get("temperature", 1.0)
|
||||||
|
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
|
||||||
min_temperature = sampling_kwargs.get("min_temperature", -1.0)
|
min_temperature = sampling_kwargs.get("min_temperature", -1.0)
|
||||||
max_duration = sampling_kwargs.get("max_duration", 500)
|
max_duration = sampling_kwargs.get("max_duration", 500)
|
||||||
beam_width = sampling_kwargs.get("beam_width", 0)
|
beam_width = sampling_kwargs.get("beam_width", 0)
|
||||||
|
@ -567,6 +590,9 @@ class AR_NAR(Base):
|
||||||
sequence_list[i] = sequence_list[i][:, 0]
|
sequence_list[i] = sequence_list[i][:, 0]
|
||||||
# start_slice[i] = sequence_list[i].shape[0]
|
# start_slice[i] = sequence_list[i].shape[0]
|
||||||
|
|
||||||
|
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
|
||||||
|
null_prom = [ None for _ in range(batch_size) ]
|
||||||
|
|
||||||
# get next in sequence
|
# get next in sequence
|
||||||
for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
|
for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
|
||||||
# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
|
# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
|
||||||
|
@ -591,6 +617,24 @@ class AR_NAR(Base):
|
||||||
#layer_skip_variables=sampling_layer_skip_variables,
|
#layer_skip_variables=sampling_layer_skip_variables,
|
||||||
output_attentions=entropix_sampling,
|
output_attentions=entropix_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg_strength > 0:
|
||||||
|
null_inputs = super().inputs(
|
||||||
|
text_list=null_text,
|
||||||
|
proms_list=null_prom,
|
||||||
|
resps_list=resps_list,
|
||||||
|
lang_list=lang_list,
|
||||||
|
tone_list=tone_list,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
)
|
||||||
|
null_output = super().forward(
|
||||||
|
inputs=null_inputs,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
#layer_skip_variables=sampling_layer_skip_variables,
|
||||||
|
)
|
||||||
|
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
|
||||||
|
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
|
||||||
|
|
||||||
logits, state = output.logits, output.state
|
logits, state = output.logits, output.state
|
||||||
|
|
||||||
sampled = super().sample(
|
sampled = super().sample(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user