made greedy AR sampling viable (and preferable), with caveats (per comment in vall_e.models.ar_nar)

This commit is contained in:
mrq 2024-10-18 16:55:00 -05:00
parent 07f4935a75
commit fc8dfd8617
6 changed files with 31 additions and 13 deletions

View File

@ -234,7 +234,7 @@ class TTS():
if "nar" in engine.hyper_config.capabilities:
model_nar = engine.module
set_seed(seed)
seed = set_seed(seed)
if task == "stt":
resp = self.encode_audio( references )

View File

@ -173,7 +173,6 @@ class AR_NAR(Base):
else:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
@ -268,6 +267,14 @@ class AR_NAR(Base):
scores = [ 1.0 ] * sampling_beam_width
entropies = []
# ick
low_temperature = sampling_repetition_penalty == 1.0 and sampling_temperature < 0.5
low_temperature_range = cfg.dataset.frames_per_second * 3
original_sampling_temperature = sampling_temperature
original_sampling_repetition_penalty = sampling_repetition_penalty
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT
if task_list[i] in text_task:
@ -284,6 +291,17 @@ class AR_NAR(Base):
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
# greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to keep the initial burst of garbage from polluting the rest of the sequence
# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
# however, switching back to a default sampling temperature once there are "clean greedy sampled codes" in context will make the rest of the sequence sound as if it were greedy sampled
# to-do: tune these values, maybe have it factor based on confidence scores or something
# to-do: see if instead just prefixing with blank audio overcomes the initial noise anyway
if low_temperature:
enabled = n < low_temperature_range
sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty
sampling_repetition_penalty_decay = 0.5 if enabled else original_sampling_repetition_penalty_decay
sampling_temperature = original_sampling_temperature if enabled else 1.0
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
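In effect, the hunk above front-loads a strong, decaying repetition penalty for roughly the first three seconds' worth of frames and only keeps the near-greedy temperature inside that window, falling back to a default temperature of 1.0 afterwards. A minimal sketch of that per-step schedule, assuming a 75 fps codec and a hypothetical helper name (neither is part of vall_e's actual API):

def low_temperature_schedule(step, frames_per_second=75,
                             original_temperature=0.0,
                             original_repetition_penalty=1.0,
                             original_repetition_penalty_decay=0.0):
    """Return (temperature, repetition_penalty, repetition_penalty_decay) for one AR step."""
    low_temperature_range = frames_per_second * 3  # ~3 seconds of audio frames
    enabled = step < low_temperature_range
    # early steps: keep the (near-)greedy temperature, but wrangle the initial
    # burst of noise with a strong, decaying repetition penalty
    repetition_penalty = 1.35 if enabled else original_repetition_penalty
    repetition_penalty_decay = 0.5 if enabled else original_repetition_penalty_decay
    # later steps: drop back to a default temperature of 1.0
    temperature = original_temperature if enabled else 1.0
    return temperature, repetition_penalty, repetition_penalty_decay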

View File

@ -114,7 +114,7 @@ except Exception as e:
# to-do: find a better way to query whether kernels are actually available, since these return true regardless
if torch.backends.cuda.flash_sdp_enabled():
AVAILABLE_ATTENTIONS.append("flash")
AVAILABLE_ATTENTIONS.append("flash_(sdpa)")
if torch.backends.cuda.mem_efficient_sdp_enabled():
AVAILABLE_ATTENTIONS.append("mem_efficient")
@ -141,7 +141,7 @@ class LlamaAttention_Adapted(LlamaAttention):
self.mode = torch.nn.attention.SDPBackend.MATH
elif self.mode == "mem_efficient":
self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
elif self.mode == "flash":
elif self.mode == "flash_(sdpa)":
self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.mode == "cudnn":
self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
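The rename only affects how the SDPA-backed flash path is labelled; the mode string still maps to PyTorch's FLASH_ATTENTION backend. A standalone sketch of that mapping, where the dictionary and tensor shapes are illustrative rather than the adapter's actual code:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# illustrative mapping that mirrors the elif chain above; not part of vall_e itself
SDPA_BACKENDS = {
    "math": SDPBackend.MATH,
    "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
    "flash_(sdpa)": SDPBackend.FLASH_ATTENTION,
    "cudnn": SDPBackend.CUDNN_ATTENTION,
}

def attend(q, k, v, mode="flash_(sdpa)"):
    # restrict scaled_dot_product_attention to the requested backend for this call
    with sdpa_kernel(SDPA_BACKENDS[mode]):
        return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 4, 16, 64)  # (batch, heads, seq_len, head_dim)
out = attend(q, k, v, mode="math")     # the math backend is always available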

View File

@ -1535,12 +1535,6 @@ class Base(nn.Module):
if quant_levels is None and "len" in self.capabilities:
logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
# argmax instead
if temperature <= 0.0:
res = [ logit.argmax(dim=1) for logit in logits ]
scores = None
return Sampled(res, scores, entropy)
# perform repetition penalizing
if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
# to-do: figure out a faster way to handle tolist()
@ -1562,7 +1556,7 @@ class Base(nn.Module):
# epsilon float comparison because I don't trust Python
if abs(temperature - min_temperature) >= 0.001:
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
else:
elif temperature > 0.0:
logits = [ logit / temperature for logit in logits ]
# do DRY sampling
@ -1585,6 +1579,10 @@ class Base(nn.Module):
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
# basic sampling
else:
res = [ Categorical(logits=logit).sample() for logit in logits ]
# argmax instead
if temperature <= 0.0:
res = [ logit.argmax(dim=1) for logit in logits ]
else:
res = [ Categorical(logits=logit).sample() for logit in logits ]
return Sampled(res, scores, entropy)
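The net effect of these hunks: greedy argmax is no longer an early return taken before repetition penalization, and logits are only divided by the temperature when it is positive; the argmax now happens at the very end, over the already-penalized logits. A condensed sketch of that ordering for a single sequence, using a generic CTRL-style repetition penalty as a stand-in for the repo's actual penalty code:

import torch
from torch.distributions import Categorical

def sample_token(logits, prev_tokens=None, temperature=0.0, repetition_penalty=1.35):
    """Condensed sketch of the new ordering: penalize first, argmax last."""
    # the repetition penalty is applied even when sampling greedily, so the
    # initial burst of noise can't dominate the argmax
    if prev_tokens is not None and repetition_penalty != 1.0:
        for token in set(prev_tokens.tolist()):
            if logits[token] > 0:
                logits[token] /= repetition_penalty
            else:
                logits[token] *= repetition_penalty
    # only scale by the temperature when it's actually positive
    if temperature > 0.0:
        return Categorical(logits=logits / temperature).sample()
    # temperature <= 0.0: greedy argmax over the penalized logits
    return logits.argmax(dim=-1)

logits = torch.randn(1024)
prev = torch.tensor([5, 5, 5, 42])
print(sample_token(logits, prev, temperature=0.0))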

View File

@ -90,6 +90,8 @@ def set_seed(seed=None):
np.random.seed(seed)
torch.manual_seed(seed)
return seed
def _get_named_modules(module, attrname):
for name, module in module.named_modules():
if hasattr(module, attrname):
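This pairs with the inference.py hunk at the top of the commit: set_seed() now returns the seed it actually used, so a caller that passed None can capture (and later report or reuse) whichever seed was chosen. A minimal sketch of the pattern; the random fallback when no seed is supplied is an assumption about the elided start of the function:

import random
import numpy as np
import torch

def set_seed(seed=None):
    if seed is None:
        # assumption: pick a seed when none is supplied so it can be reported back
        seed = random.randint(0, 2**32 - 1)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed  # returning it lets callers write `seed = set_seed(seed)`

seed = set_seed(None)
print(f"inference used seed {seed}")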

View File

@ -348,7 +348,7 @@ with ui:
#layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
with gr.Row():
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.9, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")