threw in CFG sampling for the normal model as well to experiment with

mrq 2024-11-11 20:27:38 -06:00
parent 2f56696506
commit b1f4db39c8


@@ -323,7 +323,7 @@ class AR_NAR(Base):
 				# keep unmasked tokens
 				resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
 				# update scores (conjugated to put the worst scores at the top)
-				scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ]
+				scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]

 				if cfg.experimental and max_steps > 0:
 					print( timestep, steps_until_x0, noise_p, resps_list, scores )
@@ -356,10 +356,12 @@ class AR_NAR(Base):
 		batch_size = len(resps_list)

 		max_levels = sampling_kwargs.get("max_levels", 0)

 		# convert NAR specific args
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )

+		max_levels = sampling_kwargs.get("max_levels", 0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)

 		if max_levels == 0:
 			max_levels = self.n_max_levels - 1
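(Aside: max_levels is read once before convert_kwargs and re-read after it, presumably so that both a bare max_levels and a NAR-prefixed nar_max_levels are honored, with the prefixed form winning. The new cfg_strength read sits after the conversion, so the NAR pass can likewise take its own nar_cfg_strength independent of the AR's.)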
@@ -395,6 +397,9 @@ class AR_NAR(Base):

 		prev_list = resps_list

+		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+		null_prom = [ None for _ in range(batch_size) ]
+
 		for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 			level = prev_list[0].shape[-1]
 			if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
@@ -421,6 +426,23 @@ class AR_NAR(Base):
 			)
 			logits, state = output.logits, output.state

+			if cfg_strength > 0:
+				null_inputs = super().inputs(
+					text_list=null_text,
+					proms_list=null_prom,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					quant_levels=quant_levels,
+				)
+				null_output = super().forward(
+					inputs=null_inputs,
+					quant_levels=quant_levels,
+					#layer_skip_variables=sampling_layer_skip_variables,
+				)
+				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
 			sampled = super().sample(
 				logits=logits,
 				prev_list=prev_list,
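For context, the block added above is classifier-free guidance applied at the logit level: each step runs a second forward pass with "null" conditioning (a stand-in text and no acoustic prompt), then extrapolates from the unconditional logits toward the conditional ones. A minimal self-contained sketch of the blend, with illustrative names (cfg_blend is not a function in this repo):

import torch

def cfg_blend(cond_logits: torch.Tensor, null_logits: torch.Tensor, strength: float) -> torch.Tensor:
	# classifier-free guidance: null + (cond - null) * strength
	# strength == 1.0 returns the conditional logits unchanged;
	# strength > 1.0 extrapolates further away from the null output
	return null_logits + (cond_logits - null_logits) * strength

# toy shapes: 4 decode steps over a 1024-entry vocabulary
cond = torch.randn(4, 1024)  # logits from the real text/prompt conditioning
null = torch.randn(4, 1024)  # logits from the null text/prompt
guided = cfg_blend(cond, null, strength=3.0)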
@@ -465,6 +487,7 @@ class AR_NAR(Base):
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )

 		temperature = sampling_kwargs.get("temperature", 1.0)
+		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
 		min_temperature = sampling_kwargs.get("min_temperature", -1.0)
 		max_duration = sampling_kwargs.get("max_duration", 500)
 		beam_width = sampling_kwargs.get("beam_width", 0)
@@ -567,6 +590,9 @@ class AR_NAR(Base):
 					sequence_list[i] = sequence_list[i][:, 0]
 					# start_slice[i] = sequence_list[i].shape[0]

+		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+		null_prom = [ None for _ in range(batch_size) ]
+
 		# get next in sequence
 		for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
 			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
@@ -591,6 +617,24 @@ class AR_NAR(Base):
 				#layer_skip_variables=sampling_layer_skip_variables,
 				output_attentions=entropix_sampling,
 			)
+
+			if cfg_strength > 0:
+				null_inputs = super().inputs(
+					text_list=null_text,
+					proms_list=null_prom,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					quant_levels=quant_levels,
+				)
+				null_output = super().forward(
+					inputs=null_inputs,
+					quant_levels=quant_levels,
+					#layer_skip_variables=sampling_layer_skip_variables,
+				)
+				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
+					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
 			logits, state = output.logits, output.state

 			sampled = super().sample(
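Two properties of this parameterization are worth noting: the if cfg_strength > 0 guard means a strength of 0 skips the null forward pass entirely rather than degenerating to the unconditional logits, and a strength of 1.0 reproduces the conditional logits exactly, so meaningful guidance only starts above 1.0. Either way, enabling it doubles the forward passes per decode step in both the AR and NAR loops.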