haha... (do not create a token dropout/noise mask when not training; sadly, this didn't fix the NAR-len output)
This commit is contained in:
parent b09328069e
commit 663f07038d
@@ -102,6 +102,11 @@ class AR_NAR(Base):
 			if task in text_task:
 				quant_levels[i] = 0 # self.n_resp_levels - 1
 			elif lo <= quant_levels[i] and quant_levels[i] <= hi and random.random() < masking_train_p:
+				# to-do: prioritize lower timesteps over later timesteps
+				# ...except that the masking rate is still tied to the cosine scheduling, which does this already
+				#r = random.random()
+				#p = math.acos(r) / (math.pi * 0.5)
+				#timesteps[i] = 1.0 - clamp(p, 0.0, 1.0)
 				timesteps[i] = random.random()
 
 		# trim resps to only contain all levels below the target level
@@ -237,7 +242,7 @@ class AR_NAR(Base):
 		if start_noise > 0.0 and resps_list is not None:
 			noise_p = math.cos( start_noise * math.pi * 0.5 )
 			mask = [ torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device ) for seq_len in len_list ]
-			resps_list = [ torch.where( mask, self.stop_token, resps[:, 0] ) for seq_len, resps in zip( len_list, resps_list ) ]
+			resps_list = [ torch.where( is_masked, self.stop_token, resps if resps.dim() == 1 else resps[:, 0] ) for is_masked, seq_len, resps in zip( mask, len_list, resps_list ) ]
 		else:
 			resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
 
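For context, the hunk above seeds the NAR-len sequence by re-masking a cosine-scheduled fraction of positions with the stop/mask token. A minimal standalone sketch, with a made-up mask token id and made-up lengths:

```python
import math
import random
import torch

MASK_TOKEN = 1024                     # stand-in for self.stop_token, not the real id
start_noise = 0.5                     # partial denoise; 0.0 would mean "mask everything"
len_list = [8, 5]                     # assumed per-sequence lengths

noise_p = math.cos(start_noise * math.pi * 0.5)   # fraction of positions to re-mask
masks = [torch.tensor([random.random() < noise_p for _ in range(n)], dtype=torch.bool) for n in len_list]
resps = [torch.randint(0, 1024, (n,)) for n in len_list]   # pretend prior responses
resps = [torch.where(m, torch.full_like(r, MASK_TOKEN), r) for m, r in zip(masks, resps)]
```

With `start_noise` at 0 the cosine gives `noise_p = 1.0`, so every position would be masked, matching the `else` branch that simply fills the sequence with the stop token.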
@@ -248,6 +253,7 @@ class AR_NAR(Base):
 		prev_list = resps_list
 
 		for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
+			annealing = (steps_until_x0 / max_steps)
 			# get noise level, per cosine scheduling
 			noise_p = math.cos( timestep * math.pi * 0.5 )
 			# pick the worst scoring tokens to mask off
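The loop above walks `timestep` from `start_noise` to `end_noise` while `noise_p` follows a cosine schedule and `annealing` shrinks toward zero; a rough illustration, assuming endpoints of 0.0 and 1.0:

```python
import math
import torch

max_steps = 5
start_noise, end_noise = 0.0, 1.0     # assumed endpoints
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
    t = timestep.item()
    annealing = steps_until_x0 / max_steps          # 0.80, 0.60, 0.40, 0.20, 0.00
    noise_p = math.cos(t * math.pi * 0.5)           # 1.00, 0.92, 0.71, 0.38, 0.00
    print(f"t={t:.2f} noise_p={noise_p:.2f} annealing={annealing:.2f}")
```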
@@ -293,7 +299,7 @@ class AR_NAR(Base):
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)
 			for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
-				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * (cfg_strength * timestep)
 
 			# sample with sampler settings
 			filtered_sampled = super().sample(
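The changed line scales classifier-free guidance by the current timestep, so guidance is weak while the sequence is still mostly masked and ramps up as denoising proceeds. A sketch with dummy shapes:

```python
import torch

cfg_strength, timestep = 3.0, 0.25
logit = torch.randn(10, 1025)        # conditional logits for one sequence (made-up shape)
null_logit = torch.randn(10, 1025)   # unconditional ("null") logits
guided = null_logit + (logit - null_logit) * (cfg_strength * timestep)
```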
@@ -301,7 +307,7 @@ class AR_NAR(Base):
 				prev_list=prev_list,
 				quant_levels=quant_levels,
 
-				temperature=temperature * (steps_until_x0 / max_steps),
+				temperature=temperature * annealing,
 				**sampling_kwargs,
 			)
 
@@ -319,8 +325,8 @@ class AR_NAR(Base):
 
 			# sample with gumbel noise
 			# This actually lobotomizes things
-			#sampled_ids = [ gumbel_sample( logits, temperature=temperature * (steps_until_x0 / max_steps), dim=-1 ) for logits in filtered_sampled.logits[0] ]
-			sampled_ids = filtered_sampled[0]
+			#sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits[0] ]
+			sampled_ids = filtered_sampled.ids
 
 			# keep unmasked tokens
 			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
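The last context line is the core of iterative masked decoding: newly sampled ids only overwrite positions that were masked this step, everything else is carried over. A toy example:

```python
import torch

resps       = torch.tensor([11, 22, 33, 44])
sampled_ids = torch.tensor([91, 92, 93, 94])
is_masked   = torch.tensor([True, False, True, False])
resps = torch.where(is_masked, sampled_ids, resps)   # -> tensor([91, 22, 93, 44])
```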
@@ -362,24 +368,9 @@ class AR_NAR(Base):
 			for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
 				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
 
-			sampled = super().sample(
-				logits=logits,
-				prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
-				**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
-			)
-
-			# remove stop token
-			resps_list = [self._prune(r, self.stop_token) for i, r in enumerate(resps_list)]
-
-			# get how much we need to slice from the end
-			slice_lengths = [ sequence.shape[-1] for sequence in resps_list ]
-			# -1 for the stop token
-			logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
+			logits = [ logit[-length-1:-1] for logit, length in zip(logits, len_list) ]
 			# greedy sample from the sequence
 			refined_list = [ logit.argmax(dim=-1) for logit in logits ]
 			# to-do: compare scores
 			# set the "refined" list as the output
 			resps_list = refined_list
-
-			if cfg.experimental and max_steps > 0:
-				print( timestep, steps_until_x0, noise_p, resps_list, scores )
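After the change, the refinement pass no longer re-samples or prunes the stop token; it just slices the last `seq_len` logits (skipping the final stop-token slot) and takes the argmax. A sketch with dummy values:

```python
import torch

seq_len = 6
logit = torch.randn(40, 1025)            # all output positions x vocab (made-up shape)
resp_logits = logit[-seq_len - 1:-1]     # last seq_len frames, excluding the final stop slot
refined = resp_logits.argmax(dim=-1)     # greedy pick per frame
```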
@@ -446,6 +437,19 @@ class AR_NAR(Base):
 				**sampling_kwargs,
 			)
 
+			"""
+			resps_list = self.forward_nar_masked(
+				text_list=text_list,
+				proms_list=proms_list,
+				resps_list=resps_list,
+				task_list=task_list,
+				lang_list=lang_list,
+				tone_list=tone_list,
+				len_list=len_list,
+				**(sampling_kwargs|{"denoise_start": 0.5}),
+			)
+			"""
+
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(resps_list):
 			if resp.dim() == 1:
@@ -508,7 +512,7 @@ class AR_NAR(Base):
 				**(sampling_kwargs | {"temperature": 0.0}),
 			)
 
-			resps_list = sampled[0]
+			resps_list = sampled.ids
 			prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 
 		return prev_list
@@ -703,7 +707,7 @@ class AR_NAR(Base):
 				**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
 			)
 
-			r = sampled[0]
+			ids = sampled.ids
 
 			if cfg.experimental:
 				if sampled.entropy:
@@ -730,12 +734,12 @@ class AR_NAR(Base):
 					scores = [ scores[i] + score for i, score in enumerate(s) ]
 
 			# append tokens
-			for i, ri in enumerate(r):
+			for i, token in enumerate(ids):
 				task = task_list[i]
 				stop_token = audio_stop_token if task not in text_task else text_stop_token
-				if stop_token in ri:
+				if stop_token in token:
 					stopped[i] = True
-				sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
+				sequence_list[i] = torch.cat([sequence_list[i], token.to(device)])
 
 			# stop token found
 			# stopped |= r == stop_token

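The renamed loop above appends each batch item's newly sampled tokens and flags the item as stopped once the stop token appears; roughly:

```python
import torch

audio_stop_token = 1024                  # placeholder id
stopped = [False, False]
new_tokens = [torch.tensor([5, 9, audio_stop_token]), torch.tensor([7, 8, 2])]
sequences = [torch.tensor([1, 2]), torch.tensor([3, 4])]

for i, token in enumerate(new_tokens):
    if audio_stop_token in token:
        stopped[i] = True
    sequences[i] = torch.cat([sequences[i], token])
```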
@@ -39,7 +39,7 @@ from ..data import get_task_symmap
 
 # these seem more elegant than a dict
 Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
-Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])
+Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
 """
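Since `Sampled` is a namedtuple, renaming `out` to `ids` keeps positional access working while making attribute access self-describing:

```python
from collections import namedtuple

Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
sampled = Sampled(ids=[3, 1, 4], logits=None, scores=None, entropy=None)
assert sampled[0] == sampled.ids   # call sites still using [0] keep working
```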
@@ -1028,8 +1028,8 @@ class Base(nn.Module):
 			if resps_list is not None and resps_list[i] is not None:
 				inputs[i].append( ( "resp", resps_list[i] ) )
 
-				# store dropout mask
-				if timestep is not None:
+				# store dropout mask (if training)
+				if timestep is not None and self.training:
 					dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
 					inputs[i].append( ("dropout_mask", dropout_mask ) )
 
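This is the change the commit message refers to: the dropout mask is only built while the module is training, so inference inputs no longer carry one. A minimal sketch of the guard; `_dropout_mask` here is a guessed Bernoulli stand-in, not the repo's actual helper:

```python
import math
import torch
from torch import nn

def _dropout_mask(resps: torch.Tensor, p: float) -> torch.Tensor:
    # True marks positions to drop/mask; assumed behaviour, not the real implementation
    return torch.rand(resps.shape[0]) < p

class Toy(nn.Module):
    def build_inputs(self, resps: torch.Tensor, timestep: float | None):
        inputs = [("resp", resps)]
        # store dropout mask (if training)
        if timestep is not None and self.training:
            mask = _dropout_mask(resps, p=math.cos(timestep * math.pi * 0.5))
            inputs.append(("dropout_mask", mask))
        return inputs

toy = Toy().eval()   # in eval mode no mask entry is added, which is the point of the fix
print(toy.build_inputs(torch.arange(8), timestep=0.3))
```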
@@ -1558,6 +1558,10 @@ class Base(nn.Module):
 
 			return early
 
+		# derive quant levels from inputs if not provided
+		if quant_levels is None:
+			quant_levels = self.get_input( inputs, "quant_level" )
+
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 
 		x, mask = list_to_tensor(x_list)
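The added fallback pulls quant levels back out of the assembled inputs when the caller didn't pass them. If `get_input` simply scans each sequence's (name, value) entries, it could look roughly like this (a guess, not the actual helper):

```python
def get_input(inputs, name):
    found = []
    for entries in inputs:                        # one list of (name, value) pairs per sequence
        found.append(next((v for n, v in entries if n == name), None))
    return found

inputs = [[("text", "..."), ("quant_level", 0)], [("text", "..."), ("quant_level", 2)]]
quant_levels = get_input(inputs, "quant_level")   # -> [0, 2]
```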
@@ -1680,7 +1684,7 @@ class Base(nn.Module):
 		self,
 		logits: list[Tensor], # logit scores
 		prev_list: list[Tensor] | None = None, # previous tokens
-		quant_levels: int | list[int] | Tensor | None = None,
+		quant_levels: int | list[int] | Tensor | None = None, # to-do: derive this from the prev_list
 		**sampling_kwargs,
 	):
 		# yikes
@@ -1767,12 +1771,7 @@ class Base(nn.Module):
 
 		# perform repetition penalizing
 		if prev_list is not None and repetition_penalty != 1.0:
-			# penalize non-autoregressively
-			if quant_levels is not None:
-				logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
-			# penalize autoregressively
-			else:
-				logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+			logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:

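Both branches applied the identical list comprehension, so they collapse into one call. For intuition, a simplified, guessed decayed repetition penalty (the repo's `reptition_penalize` may differ):

```python
import math
import torch

def repetition_penalize(logit, previous, factor=1.25, decay=0.0):
    logit = logit.clone()
    for distance, token in enumerate(reversed(previous.tolist())):
        # penalty fades toward 1.0 (no effect) the further back the token appeared
        strength = 1.0 + (factor - 1.0) * math.exp(-decay * distance)
        value = logit[..., token]
        # shrink positive logits, push negative logits further down (CTRL-style)
        logit[..., token] = value / strength if value > 0 else value * strength
    return logit

penalized = repetition_penalize(torch.randn(1025), previous=torch.tensor([3, 3, 7]))
```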
@@ -428,7 +428,7 @@ with ui:
 					layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
 					layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=3.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
+					layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
 					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():
@@ -437,7 +437,7 @@ with ui:
 					layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
 					layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+					layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 					layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 					layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 				with gr.Row():
|
Loading…
Reference in New Issue
Block a user