diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index d8ed04e..f0a6e87 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -102,6 +102,11 @@ class AR_NAR(Base):
 			if task in text_task:
 				quant_levels[i] = 0 # self.n_resp_levels - 1
 			elif lo <= quant_levels[i] and quant_levels[i] <= hi and random.random() < masking_train_p:
+				# to-do: prioritize lower timesteps over later timesteps
+				# ...except that the masking rate is still tied to the cosine scheduling, which does this already
+				#r = random.random()
+				#p = math.acos(r) / (math.pi * 0.5)
+				#timesteps[i] = 1.0 - clamp(p, 0.0, 1.0)
 				timesteps[i] = random.random()
 
 				# trim resps to only contain all levels below the target level
@@ -237,7 +242,7 @@ class AR_NAR(Base):
 		if start_noise > 0.0 and resps_list is not None:
 			noise_p = math.cos( start_noise * math.pi * 0.5 )
 			mask = [ torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device ) for seq_len in len_list ]
-			resps_list = [ torch.where( mask, self.stop_token, resps[:, 0] ) for seq_len, resps in zip( len_list, resps_list ) ]
+			resps_list = [ torch.where( is_masked, self.stop_token, resps if resps.dim() == 1 else resps[:, 0] ) for is_masked, seq_len, resps in zip( mask, len_list, resps_list ) ]
 		else:
 			resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
 
@@ -248,6 +253,7 @@ class AR_NAR(Base):
 		prev_list = resps_list
 
 		for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
+			annealing = (steps_until_x0 / max_steps)
 			# get noise level, per cosine scheduling
 			noise_p = math.cos( timestep * math.pi * 0.5 )
 			# pick the worst scoring tokens to mask off
@@ -293,7 +299,7 @@ class AR_NAR(Base):
 				#layer_skip_variables=sampling_layer_skip_variables,
 			)
 			for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
-				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * (cfg_strength * timestep)
 
 			# sample with sampler settings
 			filtered_sampled = super().sample(
@@ -301,7 +307,7 @@ class AR_NAR(Base):
 				prev_list=prev_list,
 				quant_levels=quant_levels,
 
-				temperature=temperature * (steps_until_x0 / max_steps),
+				temperature=temperature * annealing,
 				**sampling_kwargs,
 			)
 
@@ -319,8 +325,8 @@ class AR_NAR(Base):
 
 			# sample with gumbelnoise
 			# This actually lobotomizes things
-			#sampled_ids = [ gumbel_sample( logits, temperature=temperature * (steps_until_x0 / max_steps), dim=-1 ) for logits in filtered_sampled.logits[0] ]
-			sampled_ids = filtered_sampled[0]
+			#sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits[0] ]
+			sampled_ids = filtered_sampled.ids
 
 			# keep unmasked tokens
 			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
@@ -362,24 +368,9 @@ class AR_NAR(Base):
 			for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
 				logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
 
-		sampled = super().sample(
-			logits=logits,
-			prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
-			**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
-		)
-
-		# remove stop token
-		resps_list = [self._prune(r, self.stop_token) for i, r in enumerate(resps_list)]
-
-		# get how much we need to slice from the end
-		slice_lengths = [ sequence.shape[-1] for sequence in resps_list ]
-		# -1 for the stop token
-		logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
+		logits = [ logit[-length-1:-1] for logit, length in zip(logits, len_list) ]
 
 		# greedy sample from the sequence
 		refined_list = [ logit.argmax(dim=-1) for logit in logits ]
-		# to-do: compare scores
-		# set the "refined" list as the output
-		resps_list = refined_list
 
 		if cfg.experimental and max_steps > 0:
 			print( timestep, steps_until_x0, noise_p, resps_list, scores )
@@ -446,6 +437,19 @@ class AR_NAR(Base):
 				**sampling_kwargs,
 			)
 
+		"""
+		resps_list = self.forward_nar_masked(
+			text_list=text_list,
+			proms_list=proms_list,
+			resps_list=resps_list,
+			task_list=task_list,
+			lang_list=lang_list,
+			tone_list=tone_list,
+			len_list=len_list,
+			**(sampling_kwargs|{"denoise_start": 0.5}),
+		)
+		"""
+
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(resps_list):
 			if resp.dim() == 1:
@@ -508,7 +512,7 @@ class AR_NAR(Base):
 				**(sampling_kwargs | {"temperature": 0.0}),
 			)
 
-			resps_list = sampled[0]
+			resps_list = sampled.ids
 			prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
 
 		return prev_list
@@ -703,7 +707,7 @@ class AR_NAR(Base):
 				**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
 			)
 
-			r = sampled[0]
+			ids = sampled.ids
 
 			if cfg.experimental:
 				if sampled.entropy:
@@ -730,12 +734,12 @@ class AR_NAR(Base):
 					scores = [ scores[i] + score for i, score in enumerate(s) ]
 
 			# append tokens
-			for i, ri in enumerate(r):
+			for i, token in enumerate(ids):
 				task = task_list[i]
 				stop_token = audio_stop_token if task not in text_task else text_stop_token
-				if stop_token in ri:
+				if stop_token in token:
 					stopped[i] = True
-				sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
+				sequence_list[i] = torch.cat([sequence_list[i], token.to(device)])
 
 			# stop token found
 			# stopped |= r == stop_token
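
For context on the scheduling changes above (the new `annealing` term, the cosine `noise_p`, and scaling `cfg_strength` by `timestep`): a minimal, self-contained sketch of just the per-step schedule follows. `demask_schedule` is a hypothetical helper for illustration, not a function in the repo; only `torch.linspace` and the formulas from the hunks are taken from the patch.

```python
# Illustrative sketch of the per-step quantities the masked-NAR loop derives.
import math
import torch

def demask_schedule(max_steps: int, start_noise: float = 0.0, end_noise: float = 1.0):
	"""Yield (timestep, noise_p, annealing) for each demasking step."""
	timesteps = torch.linspace(start_noise, end_noise, max_steps)
	for timestep, steps_until_x0 in zip(timesteps, reversed(range(max_steps))):
		# cosine schedule: most tokens re-masked early, few re-masked late
		noise_p = math.cos(timestep * math.pi * 0.5)
		# annealing runs from just under 1.0 down to 0.0, cooling the temperature
		annealing = steps_until_x0 / max_steps
		yield timestep.item(), noise_p, annealing

for t, noise_p, annealing in demask_schedule(max_steps=8):
	# per the hunk above, CFG is also ramped in with the timestep:
	#   logit = null_logit + (logit - null_logit) * (cfg_strength * t)
	print(f"t={t:.2f}  noise_p={noise_p:.2f}  annealing={annealing:.2f}")
```
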
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 809773d..dc26f7a 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -39,7 +39,7 @@ from ..data import get_task_symmap
 
 # these seem more elegant than a dict
 Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
-Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])
+Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
 """
@@ -1028,8 +1028,8 @@ class Base(nn.Module):
 			if resps_list is not None and resps_list[i] is not None:
 				inputs[i].append( ( "resp", resps_list[i] ) )
 
-			# store dropout mask
-			if timestep is not None:
+			# store dropout mask (if training)
+			if timestep is not None and self.training:
 				dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
 				inputs[i].append( ("dropout_mask", dropout_mask ) )
 
@@ -1558,6 +1558,10 @@ class Base(nn.Module):
 
 			return early
 
+		# derive quant levels from inputs if not provided
+		if quant_levels is None:
+			quant_levels = self.get_input( inputs, "quant_level" )
+
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 
 		x, mask = list_to_tensor(x_list)
@@ -1680,7 +1684,7 @@ class Base(nn.Module):
 		self,
 		logits: list[Tensor], # logit scores
 		prev_list: list[Tensor] | None = None, # previous tokens
-		quant_levels: int | list[int] | Tensor | None = None,
+		quant_levels: int | list[int] | Tensor | None = None, # to-do: derive this from the prev_list
 		**sampling_kwargs,
 	):
 		# yikes
@@ -1767,12 +1771,7 @@ class Base(nn.Module):
 
 		# perform repetition penalizing
 		if prev_list is not None and repetition_penalty != 1.0:
-			# penalize non-autoregressively
-			if quant_levels is not None:
-				logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
-			# penalize autoregressively
-			else:
-				logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+			logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
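
A quick note on why the `Sampled` rename above is safe: namedtuples keep positional indexing, so `sampled[0]` and `sampled.ids` return the same field, and the attribute form merely documents intent at the call sites updated in ar_nar.py. A minimal demonstration, using the field list straight from the patch:

```python
from collections import namedtuple

# definition as it appears in the + line of the @@ -39 hunk
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])

sampled = Sampled(ids=[1, 2, 3], logits=None, scores=None, entropy=None)
# positional access still works after the rename, so any stragglers
# using sampled[0] keep behaving identically
assert sampled[0] is sampled.ids
```
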
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 4c2b2db..8d951df 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -428,7 +428,7 @@ with ui:
 						layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
 						layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=3.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
+						layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=0.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
 						layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 				with gr.Tab("Sampler Settings"):
 					with gr.Row():
@@ -437,7 +437,7 @@ with ui:
 						layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
 						layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 					with gr.Row():
-						layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+						layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=0.0, maximum=5.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 						layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 						layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 					with gr.Row():
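
The widened repetition-penalty slider bounds ([0, 5] instead of [-2, 2]) line up with how a CTRL-style multiplicative penalty behaves, where a negative factor would flip logit signs rather than penalize. Below is a stand-in sketch of that penalty rule under the usual divide-positive/multiply-negative convention; it is illustrative only and is not the repo's `reptition_penalize` helper (which, per the base.py hunk, also takes a `decay` term omitted here):

```python
import torch

def penalize_repeats(logits: torch.Tensor, previous: torch.Tensor, factor: float = 1.25) -> torch.Tensor:
	"""CTRL-style repetition penalty over a (vocab,) row of logits."""
	out = logits.clone()
	seen = previous.unique()
	# dividing positive logits (and multiplying negative ones) by factor > 1
	# uniformly lowers the probability of already-seen tokens; factor = 1.0 is a no-op
	out[seen] = torch.where(out[seen] > 0, out[seen] / factor, out[seen] * factor)
	return out

logits = torch.tensor([2.0, -1.0, 0.5])
print(penalize_repeats(logits, previous=torch.tensor([0, 1])))
# tensor([ 1.6000, -1.2500,  0.5000])
```
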