haha... (do not create a token dropout/noise mask when not training (this sadly didn't fix NAR-len output))

mrq 2024-11-12 16:41:58 -06:00
parent b09328069e
commit 663f07038d
3 changed files with 41 additions and 38 deletions

View File

@@ -102,6 +102,11 @@ class AR_NAR(Base):
if task in text_task:
quant_levels[i] = 0 # self.n_resp_levels - 1
elif lo <= quant_levels[i] and quant_levels[i] <= hi and random.random() < masking_train_p:
+# to-do: prioritize lower timesteps over later timesteps
+# ...except that the masking rate is still tied to the cosine scheduling, which does this already
+#r = random.random()
+#p = math.acos(r) / (math.pi * 0.5)
+#timesteps[i] = 1.0 - clamp(p, 0.0, 1.0)
timesteps[i] = random.random()
# trim resps to only contain all levels below the target level
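
Note: the to-do above mostly answers itself, as the second comment says: because the masking rate is tied to the cosine schedule, a uniform timestep already produces heavier masking at low t. A minimal standalone sketch (illustration, not repo code) of that mapping:

import math, random

def mask_rate(timestep: float) -> float:
    # cosine schedule used throughout this commit: full masking at t=0, none at t=1
    return math.cos(timestep * math.pi * 0.5)

t = random.random()  # uniform draw, as in timesteps[i] = random.random() above
p = mask_rate(t)     # e.g. t=0.1 -> ~0.99, t=0.9 -> ~0.16
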
@@ -237,7 +242,7 @@ class AR_NAR(Base):
if start_noise > 0.0 and resps_list is not None:
noise_p = math.cos( start_noise * math.pi * 0.5 )
mask = [ torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device ) for seq_len in len_list ]
-resps_list = [ torch.where( mask, self.stop_token, resps[:, 0] ) for seq_len, resps in zip( len_list, resps_list ) ]
+resps_list = [ torch.where( is_masked, self.stop_token, resps if resps.dim() == 1 else resps[:, 0] ) for is_masked, seq_len, resps in zip( mask, len_list, resps_list ) ]
else:
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
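
Note: the old line handed the whole list of masks to torch.where and unconditionally indexed a codebook dimension; the fix zips a per-sequence is_masked in and tolerates 1D responses. A hedged, self-contained sketch of the corrected behavior (the stop_token value and shapes are placeholders):

import torch

stop_token = 1024  # placeholder mask/stop id, assumed for illustration
noise_p = 1.0      # full masking when starting from scratch
len_list = [4, 6]
resps_list = [torch.zeros(4, dtype=torch.long), torch.zeros(6, 8, dtype=torch.long)]

mask = [torch.rand(l) < noise_p for l in len_list]
resps_list = [
    # one boolean mask per sequence; 1D resps no longer hit a missing [:, 0]
    torch.where(m, stop_token, r if r.dim() == 1 else r[:, 0])
    for m, r in zip(mask, resps_list)
]
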
@@ -248,6 +253,7 @@ class AR_NAR(Base):
prev_list = resps_list
for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
+annealing = (steps_until_x0 / max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# pick the worst scoring tokens to mask off
@@ -293,7 +299,7 @@ class AR_NAR(Base):
#layer_skip_variables=sampling_layer_skip_variables,
)
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
-logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * (cfg_strength * timestep)
# sample with sampler settings
filtered_sampled = super().sample(
@@ -301,7 +307,7 @@ class AR_NAR(Base):
prev_list=prev_list,
quant_levels=quant_levels,
-temperature=temperature * (steps_until_x0 / max_steps),
+temperature=temperature * annealing,
**sampling_kwargs,
)
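
Note: two schedules change in this pass. Guidance strength now ramps with the timestep (weak while most tokens are still masked, full strength near t=1), and temperature anneals toward zero so the last demasking steps are effectively greedy. Sketched together:

import torch

def guided(logit: torch.Tensor, null_logit: torch.Tensor,
           cfg_strength: float, timestep: float) -> torch.Tensor:
    # classifier-free guidance, rescaled by the current timestep
    return null_logit + (logit - null_logit) * (cfg_strength * timestep)

def step_temperature(temperature: float, steps_until_x0: int, max_steps: int) -> float:
    # annealing = steps_until_x0 / max_steps: 1.0 on the first step, ~0.0 on the last
    return temperature * (steps_until_x0 / max_steps)
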
@@ -319,8 +325,8 @@ class AR_NAR(Base):
# sample with gumbelnoise
# This actually lobotomizes things
-#sampled_ids = [ gumbel_sample( logits, temperature=temperature * (steps_until_x0 / max_steps), dim=-1 ) for logits in filtered_sampled.logits[0] ]
-sampled_ids = filtered_sampled[0]
+#sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits[0] ]
+sampled_ids = filtered_sampled.ids
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
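
Note: each step only commits the freshly sampled ids at positions that were masked this round; everything already unmasked is kept verbatim, which is the usual masked-demasking update rule. A tiny self-contained demo:

import torch

is_masked = torch.tensor([True, False, True])
sampled   = torch.tensor([7, 8, 9])
committed = torch.tensor([0, 5, 0])
committed = torch.where(is_masked, sampled, committed)  # -> tensor([7, 5, 9])
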
@@ -362,24 +368,9 @@ class AR_NAR(Base):
for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
-sampled = super().sample(
-logits=logits,
-prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
-**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
-)
-# remove stop token
-resps_list = [self._prune(r, self.stop_token) for i, r in enumerate(resps_list)]
-# get how much we need to slice from the end
-slice_lengths = [ sequence.shape[-1] for sequence in resps_list ]
-# -1 for the stop token
-logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
-# greedy sample from the sequence
-refined_list = [ logit.argmax(dim=-1) for logit in logits ]
-# to-do: compare scores
-# set the "refined" list as the output
-resps_list = refined_list
+logits = [ logit[-length-1:-1] for logit, length in zip(logits, len_list) ]
if cfg.experimental and max_steps > 0:
print( timestep, steps_until_x0, noise_p, resps_list, scores )
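
Note: in the NAR-len path the sequence lengths are known up front, so the deleted stop-token pruning and slice_lengths bookkeeping were dead weight; slicing by len_list directly is enough. A hedged sketch of the surviving slice (shapes are placeholders; the -1 skips the would-be stop position):

import torch

len_list = [3]
logits = [torch.randn(10, 1025)]
logits = [logit[-l - 1:-1] for logit, l in zip(logits, len_list)]
refined = [logit.argmax(dim=-1) for logit in logits]  # greedy decode per sequence
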
@@ -446,6 +437,19 @@ class AR_NAR(Base):
**sampling_kwargs,
)
+"""
+resps_list = self.forward_nar_masked(
+text_list=text_list,
+proms_list=proms_list,
+resps_list=resps_list,
+task_list=task_list,
+lang_list=lang_list,
+tone_list=tone_list,
+len_list=len_list,
+**(sampling_kwargs|{"denoise_start": 0.5}),
+)
+"""
# expand if given a raw 1D tensor
for i, resp in enumerate(resps_list):
if resp.dim() == 1:
@@ -508,7 +512,7 @@ class AR_NAR(Base):
**(sampling_kwargs | {"temperature": 0.0}),
)
-resps_list = sampled[0]
+resps_list = sampled.ids
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list
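
Note: in the per-level NAR loop, each pass emits one token per position for the current codebook level; unsqueeze(-1) turns that vector into a new column so levels stack along the last dimension. Tiny sketch:

import torch

prev = torch.zeros(5, 2, dtype=torch.long)   # 5 positions, 2 levels decoded so far
level = torch.ones(5, dtype=torch.long)      # greedy ids for the next level
prev = torch.cat([prev, level.unsqueeze(-1)], dim=-1)  # now shape (5, 3)
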
@@ -703,7 +707,7 @@ class AR_NAR(Base):
**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
)
-r = sampled[0]
+ids = sampled.ids
if cfg.experimental:
if sampled.entropy:
@@ -730,12 +734,12 @@ class AR_NAR(Base):
scores = [ scores[i] + score for i, score in enumerate(s) ]
# append tokens
-for i, ri in enumerate(r):
+for i, token in enumerate(ids):
task = task_list[i]
stop_token = audio_stop_token if task not in text_task else text_stop_token
-if stop_token in ri:
+if stop_token in token:
stopped[i] = True
-sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
+sequence_list[i] = torch.cat([sequence_list[i], token.to(device)])
# stop token found
# stopped |= r == stop_token
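
Note: the rename from sampled[0]/r/ri to sampled.ids/ids/token is purely cosmetic; the stop check itself relies on Tensor.__contains__, so `in` is a membership test over every element of that step's tokens:

import torch

stop_token = 1024                 # placeholder id
step_tokens = torch.tensor([12, 1024])
assert stop_token in step_tokens  # True: this batch item should stop
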

View File

@@ -39,7 +39,7 @@ from ..data import get_task_symmap
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
-Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])
+Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
"""
@@ -1028,8 +1028,8 @@ class Base(nn.Module):
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
-# store dropout mask
-if timestep is not None:
+# store dropout mask (if training)
+if timestep is not None and self.training:
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
inputs[i].append( ("dropout_mask", dropout_mask ) )
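
Note: this is the change the commit title refers to. Previously the dropout mask was built whenever a timestep was passed, so masked-NAR inference conditioned on a training-only input; gating on self.training keeps it out of inference entirely. A hedged sketch of the gate (the _dropout_mask internals here are assumed, not copied from the repo):

import math, torch

def _dropout_mask_sketch(resps: torch.Tensor, p: float) -> torch.Tensor:
    # assumed semantics: True marks positions to be replaced by the mask token
    return torch.rand(resps.shape[0], device=resps.device) < p

def maybe_dropout_mask(resps: torch.Tensor, timestep: float | None, training: bool):
    # the fix: only training builds (and thus conditions on) a dropout mask
    if timestep is not None and training:
        return _dropout_mask_sketch(resps, p=math.cos(timestep * math.pi * 0.5))
    return None
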
@@ -1558,6 +1558,10 @@ class Base(nn.Module):
return early
+# derive quant levels from inputs if not provided
+if quant_levels is None:
+quant_levels = self.get_input( inputs, "quant_level" )
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, mask = list_to_tensor(x_list)
@@ -1680,7 +1684,7 @@ class Base(nn.Module):
self,
logits: list[Tensor], # logit scores
prev_list: list[Tensor] | None = None, # previous tokens
-quant_levels: int | list[int] | Tensor | None = None,
+quant_levels: int | list[int] | Tensor | None = None, # to-do: derive this from the prev_list
**sampling_kwargs,
):
# yikes
@@ -1767,12 +1771,7 @@ class Base(nn.Module):
# perform repetition penalizing
if prev_list is not None and repetition_penalty != 1.0:
-# penalize non-autoregressively
-if quant_levels is not None:
-logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
-# penalize autoregressively
-else:
-logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
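
Note: the deleted branches were byte-for-byte identical, so the AR/NAR distinction in repetition penalizing was illusory. For intuition, a hedged sketch of what a decayed repetition penalty typically does (an illustration, not the repo's reptition_penalize implementation):

import torch

def repetition_penalize_sketch(logit: torch.Tensor, previous: torch.Tensor,
                               factor: float, decay: float) -> torch.Tensor:
    # divide seen tokens' scores by `factor`, weakened for older occurrences;
    # real implementations also treat negative scores differently
    out = logit.clone()
    for age, tok in enumerate(reversed(previous.tolist())):
        out[..., tok] /= factor * (1.0 - decay) ** age
    return out
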

View File

@@ -428,7 +428,7 @@ with ui:
layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
-layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=3.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
+layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():
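
Note: the wider slider range tracks the guidance formula used in the model code above. Under that formula, 0 reduces the output to the unconditioned (null) logits (in practice a strength of 0 presumably bypasses the null pass entirely), 1 reproduces the conditioned logits, and values above 1 extrapolate past them, so a ceiling of 14 leaves room to over-guide:

# guided = null_logit + (logit - null_logit) * cfg_strength
#   cfg_strength = 0 -> null_logit   (condition ignored)
#   cfg_strength = 1 -> logit        (no extra push)
#   cfg_strength > 1 -> extrapolate beyond the conditioned logits
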
@@ -437,7 +437,7 @@ with ui:
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row():
-layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row():