made greedy AR sampling viable (and preferable), with caveats (per comment in vall_e.models.ar_nar)
commit fc8dfd8617
parent 07f4935a75
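What the diff below does, in one place: when the caller asks for a (near-)greedy AR pass, the first few seconds of frames are decoded with a strong, fast-decaying repetition penalty so the initial burst of noise can't poison the rest of the sequence, after which sampling falls back to plain temperature-1.0 settings. A minimal, standalone sketch of that schedule follows; warmup_schedule is a hypothetical helper (not part of the repo), and the 75 frame-per-second codec rate is an assumption, while the 1.35 / 0.5 / 3-second constants come straight from the hunks below.

def warmup_schedule(step: int, temperature: float, rep_pen: float, rep_pen_decay: float,
		frames_per_second: int = 75):  # 75 fps is an assumed codec frame rate
	"""Return the (temperature, rep_pen, rep_pen_decay) to use at a given AR decode step."""
	# a "greedy-ish" request: no repetition penalty and a near-zero temperature
	low_temperature = rep_pen == 1.0 and temperature < 0.5
	warmup_steps = frames_per_second * 3  # roughly the first 3 seconds of audio frames

	if not low_temperature:
		# ordinary requests are left untouched
		return temperature, rep_pen, rep_pen_decay

	if step < warmup_steps:
		# wrangle the initial burst of noise with a strong, fast-decaying rep pen
		return temperature, 1.35, 0.5

	# past the warm-up, fall back to plain temperature-1.0 sampling; the clean greedy
	# prefix keeps the remainder sounding as if it were greedy sampled throughout
	return 1.0, rep_pen, rep_pen_decay

# e.g. a greedy request (temperature 0, rep pen 1.0):
#   warmup_schedule(10, 0.0, 1.0, 0.0)  -> (0.0, 1.35, 0.5)
#   warmup_schedule(400, 0.0, 1.0, 0.0) -> (1.0, 1.0, 0.0)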
@@ -234,7 +234,7 @@ class TTS():
 			if "nar" in engine.hyper_config.capabilities:
 				model_nar = engine.module
 
-		set_seed(seed)
+		seed = set_seed(seed)
 
 		if task == "stt":
 			resp = self.encode_audio( references )
@@ -172,7 +172,6 @@ class AR_NAR(Base):
 					...
 				else:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
-
 
 		inputs = self.inputs(
 			text_list=text_list,
@@ -268,6 +267,14 @@ class AR_NAR(Base):
 		scores = [ 1.0 ] * sampling_beam_width
 		entropies = []
 
+		# ick
+		low_temperature = sampling_repetition_penalty == 1.0 and sampling_temperature < 0.5
+		low_temperature_range = cfg.dataset.frames_per_second * 3
+
+		original_sampling_temperature = sampling_temperature
+		original_sampling_repetition_penalty = sampling_repetition_penalty
+		original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
+
 		for i, sequence in enumerate( sequence_list ):
 			# add <bos> to text for STT
 			if task_list[i] in text_task:
@@ -284,6 +291,17 @@ class AR_NAR(Base):
 			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
 			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
 
+			# greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to work around the initial burst of garbage from polluting the rest of the sequence
+			# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
+			# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of the sequence sound as if it were greedy sampled
+			# to-do: tune these values, maybe have it factor based on confidence scores or something
+			# to-do: see if instead just prefixing with blank audio overcomes the initial noise anyways
+			if low_temperature:
+				enabled = n < low_temperature_range
+				sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty
+				sampling_repetition_penalty_decay = 0.5 if enabled else original_sampling_repetition_penalty_decay
+				sampling_temperature = original_sampling_temperature if enabled else 1.0
+
 			inputs = self.inputs(
 				text_list=text_list,
 				proms_list=proms_list,
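For a rough sense of scale (assuming a 75 frame-per-second codec, which this diff does not state): low_temperature_range = 75 * 3 = 225 AR steps, so the boosted 1.35 repetition penalty only governs about the first three seconds of generated audio, comfortably inside the ~6-second mark where the comment above says relying on rep pen alone starts to fail; from that step onward the sampler reverts to the original penalty settings at temperature 1.0.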
@@ -114,7 +114,7 @@ except Exception as e:
 
 # to-do: find a better way to query for if there's available kernels since these return true regardless
 if torch.backends.cuda.flash_sdp_enabled():
-	AVAILABLE_ATTENTIONS.append("flash")
+	AVAILABLE_ATTENTIONS.append("flash_(sdpa)")
 
 if torch.backends.cuda.mem_efficient_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("mem_efficient")
@@ -141,7 +141,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 			self.mode = torch.nn.attention.SDPBackend.MATH
 		elif self.mode == "mem_efficient":
 			self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
-		elif self.mode == "flash":
+		elif self.mode == "flash_(sdpa)":
 			self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
 		elif self.mode == "cudnn":
 			self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
@@ -1535,12 +1535,6 @@ class Base(nn.Module):
 		if quant_levels is None and "len" in self.capabilities:
 			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
-
-		# argmax instead
-		if temperature <= 0.0:
-			res = [ logit.argmax(dim=1) for logit in logits ]
-			scores = None
-			return Sampled(res, scores, entropy)
 
 		# perform repetition penalizing
 		if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
 			# to-do: figure out a faster way to handle tolist()
@@ -1562,7 +1556,7 @@ class Base(nn.Module):
 		# epsilon float comparison because I don't trust Python
 		if abs(temperature - min_temperature) >= 0.001:
 			logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
-		else:
+		elif temperature > 0.0:
 			logits = [ logit / temperature for logit in logits ]
 
 		# do DRY sampling
@@ -1585,6 +1579,10 @@ class Base(nn.Module):
 			scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
 		# basic sampling
 		else:
-			res = [ Categorical(logits=logit).sample() for logit in logits ]
+			# argmax instead
+			if temperature <= 0.0:
+				res = [ logit.argmax(dim=1) for logit in logits ]
+			else:
+				res = [ Categorical(logits=logit).sample() for logit in logits ]
 
 		return Sampled(res, scores, entropy)
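The three Base.sample hunks above relocate the greedy path: instead of returning an argmax before any penalties run, the argmax/Categorical choice now happens at the very end, after the repetition penalty (and DRY) have adjusted the logits, and the temperature division is skipped when temperature is 0 so greedy requests no longer divide by zero. A rough, self-contained sketch of the resulting order of operations follows; apply_rep_pen and sample_token are simplified stand-ins for illustration, not the project's actual helpers, and they operate on a single step's (vocab,) logits.

import torch
from torch.distributions import Categorical

def apply_rep_pen(logit: torch.Tensor, prev_tokens: torch.Tensor, penalty: float) -> torch.Tensor:
	# logit: (vocab,) scores for one step; prev_tokens: token ids already generated
	logit = logit.clone()
	scores = logit[prev_tokens]
	# standard formulation: divide positive scores, multiply negative ones, to discourage repeats
	logit[prev_tokens] = torch.where(scores < 0, scores * penalty, scores / penalty)
	return logit

def sample_token(logit: torch.Tensor, prev_tokens: torch.Tensor,
		temperature: float = 0.0, repetition_penalty: float = 1.35) -> torch.Tensor:
	# 1) repetition penalty comes first, so even greedy decoding benefits from it
	if repetition_penalty != 1.0:
		logit = apply_rep_pen(logit, prev_tokens, repetition_penalty)

	# 2) temperature scaling only when temperature > 0 (no divide-by-zero on greedy requests)
	if temperature > 0.0:
		return Categorical(logits=logit / temperature).sample()

	# 3) temperature <= 0: greedy argmax over the penalized logits
	return logit.argmax(dim=-1)

With temperature 0 this reduces to repetition-penalized argmax, which is the combination the AR warm-up above leans on for its first few seconds.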
@@ -90,6 +90,8 @@ def set_seed(seed=None):
 	np.random.seed(seed)
 	torch.manual_seed(seed)
 
+	return seed
+
 def _get_named_modules(module, attrname):
 	for name, module in module.named_modules():
 		if hasattr(module, attrname):
@@ -348,7 +348,7 @@ with ui:
 					#layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
 					layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.9, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
+					layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
 					layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 				with gr.Row():
 					#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")