threw in CFG sampling for normal model as well to experiment with
parent 2f56696506
commit b1f4db39c8
@@ -323,7 +323,7 @@ class AR_NAR(Base):
 			# keep unmasked tokens
 			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
 			# update scores (conjugated to put the worst scores at the top)
-			scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ]
+			scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
 
 			if cfg.experimental and max_steps > 0:
 				print( timestep, steps_until_x0, noise_p, resps_list, scores )
@@ -356,10 +356,12 @@ class AR_NAR(Base):
 		batch_size = len(resps_list)
 
 
-		max_levels = sampling_kwargs.get("max_levels", 0)
 		# convert NAR specific args
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
 
+		max_levels = sampling_kwargs.get("max_levels", 0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+
 		if max_levels == 0:
 			max_levels = self.n_max_levels - 1
 
@@ -395,6 +397,9 @@ class AR_NAR(Base):
 
 		prev_list = resps_list
 
+		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+		null_prom = [ None for _ in range(batch_size) ]
+
 		for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 			level = prev_list[0].shape[-1]
 			if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
@@ -421,6 +426,23 @@ class AR_NAR(Base):
 			)
 			logits, state = output.logits, output.state
 
+			if cfg_strength > 0:
+				null_inputs = super().inputs(
+					text_list=null_text,
+					proms_list=null_prom,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					quant_levels=quant_levels,
+				)
+				null_output = super().forward(
+					inputs=null_inputs,
+					quant_levels=quant_levels,
+					#layer_skip_variables=sampling_layer_skip_variables,
+				)
+				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
 			sampled = super().sample(
 				logits=logits,
 				prev_list=prev_list,
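Both the NAR hunk above and the AR hunk below gate the same classifier-free guidance blend behind cfg_strength > 0: a second forward pass over null conditioning produces null_logit, and the guided logits are null + (cond - null) * cfg_strength. A minimal standalone sketch of that interpolation, assuming plain torch.Tensor logits rather than the per-sequence lists used in this diff:

import torch

def cfg_blend( cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float ) -> torch.Tensor:
	# Classifier-free guidance: start from the unconditional (null) logits and
	# move toward the conditional ones. cfg_strength == 1.0 reproduces the
	# conditional logits exactly; values > 1.0 extrapolate past them, which
	# amplifies the effect of the conditioning.
	return null_logits + ( cond_logits - null_logits ) * cfg_strength

# illustrative shapes: (sequence length, vocab size)
cond = torch.randn( 8, 1024 )
null = torch.randn( 8, 1024 )
guided = cfg_blend( cond, null, cfg_strength=2.0 )

At cfg_strength == 0 the blend would collapse to the null logits, which is presumably why the commit skips the extra null forward pass entirely in that case rather than blending with strength zero.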
@@ -465,6 +487,7 @@ class AR_NAR(Base):
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
 
 		temperature = sampling_kwargs.get("temperature", 1.0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
 		min_temperature = sampling_kwargs.get("min_temperature", -1.0)
 		max_duration = sampling_kwargs.get("max_duration", 500)
 		beam_width = sampling_kwargs.get("beam_width", 0)
@@ -567,6 +590,9 @@ class AR_NAR(Base):
 				sequence_list[i] = sequence_list[i][:, 0]
 			# start_slice[i] = sequence_list[i].shape[0]
 
+		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+		null_prom = [ None for _ in range(batch_size) ]
+
 		# get next in sequence
 		for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
 			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
@@ -591,6 +617,24 @@ class AR_NAR(Base):
 				#layer_skip_variables=sampling_layer_skip_variables,
 				output_attentions=entropix_sampling,
 			)
+
+			if cfg_strength > 0:
+				null_inputs = super().inputs(
+					text_list=null_text,
+					proms_list=null_prom,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					quant_levels=quant_levels,
+				)
+				null_output = super().forward(
+					inputs=null_inputs,
+					quant_levels=quant_levels,
+					#layer_skip_variables=sampling_layer_skip_variables,
+				)
+				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
 			logits, state = output.logits, output.state
 
 			sampled = super().sample(
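The null condition itself is just a stub text sequence ([1, 2], presumably the BOS/EOS-style special tokens) with no acoustic prompt, so the second pass scores the response tokens essentially unconditioned. Since both paths read cfg_strength out of sampling_kwargs after convert_kwargs handles the ar_/nar_ prefixed variants, enabling the experiment should only require passing the key at inference time. A hypothetical invocation; the entry point and surrounding kwargs here are illustrative, not this repo's confirmed API:

# hypothetical usage sketch: cfg_strength > 0 enables the CFG pass
# in both the AR and NAR loops patched above
resps_list = model(
	text_list=text_list,
	proms_list=proms_list,
	sampling_kwargs={
		"cfg_strength": 2.0,
		"max_duration": 500,
	},
)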