modified demo page to be more modular, with support for demoing comparisons; actually provide a path to use the modified naive attention; entropix sampling is no longer tied to an experimental yaml flag
commit 04e983b86b
parent 666e8038fb
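The hunks below span the CLI argument parsing, the demo page generator, the experimental-settings dataclass, the `TTS` inference wrapper, the `AR_NAR` sampling loop, the adapted Llama attention module, `sample_entropix`, and the Gradio web UI (per the hunk headers). The common thread: `entropix_sampling` becomes an ordinary sampling kwarg plumbed end to end, a "default" (naive) attention backend becomes selectable so attention scores are actually available to it, and the demo page gains a generic before/after comparison mode.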
@@ -43,6 +43,8 @@ def main():
 	parser.add_argument("--dry-base", type=float, default=1.75)
 	parser.add_argument("--dry-allowed-length", type=int, default=2)
 
+	parser.add_argument("--entropix-sampling", action="store_true")
+
 	parser.add_argument("--seed", type=int, default=None)
 
 	parser.add_argument("--device", type=str, default=None)
@@ -69,6 +71,7 @@ def main():
 		beam_width=args.beam_width,
 		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
 		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+		entropix_sampling=args.entropix_sampling,
 		seed=args.seed,
 	)
 
@@ -237,8 +237,6 @@ class ModelExperimentalSettings:
 	p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
 	# to-to: just incorporate this as a task instead
 
-	entropix_sampling: bool = False # experimental sampling based on https://github.com/xjdr-alt/entropix, experimental flag because it requires using naive attention for output scores
-
 # I really need to clean this up
 @dataclass()
 class Model:
@@ -76,6 +76,8 @@ def main():
 	parser.add_argument("--dry-base", type=float, default=1.75)
 	parser.add_argument("--dry-allowed-length", type=int, default=2)
 
+	parser.add_argument("--entropix-sampling", action="store_true")
+
 	parser.add_argument("--seed", type=int, default=None)
 
 	parser.add_argument("--device", type=str, default=None)
@@ -84,6 +86,7 @@ def main():
 
 	parser.add_argument("--random-prompts", action="store_true")
 	parser.add_argument("--lora", action="store_true")
+	parser.add_argument("--comparison", action="store_true")
 
 	args = parser.parse_args()
 
@@ -98,6 +101,30 @@ def main():
 		'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
 	])
 
+	# comparison kwargs
+	comparison_kwargs = {
+		"enabled": False,
+		"titles": [],
+		"suffix": "_after",
+		"before": {},
+		"after": {}
+	}
+
+	if args.lora:
+		comparison_kwargs["enabled"] = True
+		comparison_kwargs["suffix"] = "_lora"
+		comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
+		comparison_kwargs["before"]["use_lora"] = True
+		comparison_kwargs["after"]["use_lora"] = False
+	# to-do: make this user definable
+	elif args.comparison:
+		comparison_kwargs["enabled"] = True
+		comparison_kwargs["suffix"] = "_entropix"
+		comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
+		comparison_kwargs["before"]["entropix_sampling"] = True
+		comparison_kwargs["after"]["entropix_sampling"] = False
+
 	# read html template
 	html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
 
@@ -114,6 +141,7 @@ def main():
 		beam_width=args.beam_width,
 		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
 		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+		entropix_sampling=args.entropix_sampling,
 	)) )
 
 	# pull from provided samples
@@ -127,7 +155,7 @@ def main():
 	# pull from dataset samples
 	if args.sample_from_dataset:
 		cfg.dataset.cache = False
-		cfg.dataset.sample_type = "path" if args.lora else "speaker"
+		cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
 		cfg.dataset.tasks_list = [ 'tts' ]
 
 		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
@@ -180,9 +208,9 @@ def main():
 			prompt = dir / "prompt.wav"
 			reference = dir / "reference.wav"
 			out_path = dir / "out" / "ours.wav"
-			out_path_lora = dir / "out" / "ours_lora.wav"
+			out_path_comparison = dir / "out" / f"ours_{comparison_kwargs["suffix"]}.wav"
 
-			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
+			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_comparison ] if comparison_kwargs["enabled"] else [])
 
 			if not args.random_prompts or k == "librispeech":
 				extra_sources += [ reference ]
@@ -210,23 +238,24 @@ def main():
 				length_penalty=args.length_penalty,
 				beam_width=args.beam_width,
 				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+				dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+				entropix_sampling=args.entropix_sampling,
 				seed=seed,
 				tqdm=False,
 			)
 
-			if args.lora:
-				tts.enable_lora() # I don't think this is necessary with the below
-				kwargs["use_lora"] = True
+			def safe_inference():
 				try:
-					tts.inference( out_path=out_path_lora, **kwargs )
+					tts.inference( out_path=out_path_comparison, **kwargs )
 				except Exception as e:
 					print(f'Error while processing {out_path}: {e}')
-				tts.disable_lora()
-			kwargs["use_lora"] = False
-			try:
-				tts.inference( out_path=out_path, **kwargs )
-			except Exception as e:
-				print(f'Error while processing {out_path}: {e}')
+
+			if comparison_kwargs["enabled"]:
+				kwargs.update( comparison_kwargs["before"] )
+				safe_inference()
+				kwargs.update( comparison_kwargs["after"] )
+
+			safe_inference()
 
 
 			# collate entries into HTML
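Taken together with the comparison_kwargs hunk above, the demo now runs each sample twice, toggling one set of kwargs between passes instead of hard-coding the LoRA on/off dance. A minimal self-contained sketch of that harness, with `run_inference` as a hypothetical stand-in for `tts.inference` (not the repo's API):

    def run_inference(out_path, **kwargs):
        # hypothetical stand-in for tts.inference(...)
        print(f"writing {out_path} with {kwargs}")

    comparison_kwargs = {
        "enabled": True,
        "titles": ["Without Entropix", "With Entropix"],
        "suffix": "_entropix",
        "before": {"entropix_sampling": True},
        "after": {"entropix_sampling": False},
    }

    kwargs = {"max_ar_steps": 500}  # hypothetical shared sampling kwargs

    def safe_inference(out_path):
        # guard each pass so one failed sample doesn't abort the whole page
        try:
            run_inference(out_path, **kwargs)
        except Exception as e:
            print(f"Error while processing {out_path}: {e}")

    if comparison_kwargs["enabled"]:
        kwargs.update(comparison_kwargs["before"])  # first pass: the comparison variant
        safe_inference(f"out/ours{comparison_kwargs['suffix']}.wav")
        kwargs.update(comparison_kwargs["after"])   # restore the baseline settings

    safe_inference("out/ours.wav")                  # second pass: the baseline output

Note that in the hunk itself `safe_inference` closes over `out_path_comparison` for both calls; the sketch parameterizes the output path to make the two passes explicit.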
@@ -243,11 +272,12 @@ def main():
 			# write audio into template
 			html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
 
-	if args.lora:
+	if comparison_kwargs["enabled"]:
+		before, after = comparison_kwargs["titles"]
 		if args.random_prompts:
-			html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+			html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
 		else:
-			html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+			html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
 
 	# write demo page
 	open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )
@@ -206,7 +206,9 @@ class TTS():
 		dry_multiplier=0.0,
 		dry_base=1.75,
 		dry_allowed_length=2,
+		#
+		entropix_sampling=False,
+		#
 		seed = None,
 
 		out_path=None,
@@ -255,6 +257,7 @@ class TTS():
 			sampling_dry_multiplier=dry_multiplier,
 			sampling_dry_base=dry_base,
 			sampling_dry_allowed_length=dry_allowed_length,
+			sampling_entropix=entropix_sampling,
 
 			disable_tqdm=not tqdm,
 			use_lora=use_lora,
@@ -299,6 +302,7 @@ class TTS():
 			sampling_dry_multiplier=dry_multiplier,
 			sampling_dry_base=dry_base,
 			sampling_dry_allowed_length=dry_allowed_length,
+			sampling_entropix=entropix_sampling,
 
 			disable_tqdm=not tqdm,
 			use_lora=use_lora,
@@ -64,8 +64,7 @@ class AR_NAR(Base):
 		sampling_dry_multiplier=0.0,
 		sampling_dry_base=1.75,
 		sampling_dry_allowed_length=2,
-		sampling_entropix=None,
-
+		sampling_entropix=False,
 
 		disable_tqdm=False,
 		use_lora=None,
@@ -269,9 +268,6 @@ class AR_NAR(Base):
 			scores = [ 1.0 ] * sampling_beam_width
 			entropies = []
 
-			if sampling_entropix is None:
-				sampling_entropix = self.config.experimental.entropix_sampling
-
 			for i, sequence in enumerate( sequence_list ):
 				# add <bos> to text for STT
 				if task_list[i] in text_task:
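With this fallback removed (and the `entropix_sampling` field dropped from `ModelExperimentalSettings` earlier in the commit), enablement flows purely through the call chain: CLI or web-UI flag, to `TTS.inference(entropix_sampling=...)`, to the `AR_NAR` `sampling_entropix` kwarg, which is now a plain `False` default rather than `None`-then-read-the-config.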
@@ -126,7 +126,8 @@ if torch.backends.cuda.cudnn_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("cudnn")
 
 if AVAILABLE_ATTENTIONS:
 	AVAILABLE_ATTENTIONS.append("sdpa")
+	AVAILABLE_ATTENTIONS.append("default")
 
 class LlamaAttention_Adapted(LlamaAttention):
 	def __init__(self, *args, **kwargs):
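"default" here names the adapted class's own manual attention path rather than a fused kernel. Why it matters for entropix: fused SDPA/flash kernels never materialize the attention weight matrix, while a naive implementation does. A minimal sketch of the computation such a path performs (not the repo's `_forward`, just the standard math, assuming already-projected tensors):

    import math
    import torch

    def naive_attention(q, k, v, mask=None):
        # q, k, v: [batch, heads, seq_len, head_dim]
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        if mask is not None:
            scores = scores + mask  # additive mask: -inf at disallowed positions
        weights = torch.softmax(scores, dim=-1)  # materialized attention scores
        return weights @ v, weights              # fused kernels never return `weights`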
@@ -144,6 +145,8 @@ class LlamaAttention_Adapted(LlamaAttention):
 			self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
 		elif self.mode == "cudnn":
 			self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+		else:
+			self.mode = None
 
 		super().__init__(*args, **kwargs)
 
@@ -256,7 +259,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
 		**kwargs,
 	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-		if output_attentions:
+		if output_attentions or not self.mode:
 			# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
 			return self._forward(
 				hidden_states=hidden_states,
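Combined with the `else: self.mode = None` fallback above, selecting the "default" backend leaves `self.mode` falsy, so every call takes the `self._forward` route, the manual implementation that can actually populate `output_attentions`, rather than only doing so when attention outputs are explicitly requested.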
@@ -531,7 +531,7 @@ class Base(nn.Module):
 		if AVAILABLE_ATTENTIONS:
 			attention_backend = AVAILABLE_ATTENTIONS[0]
 		else:
-			attention_backend = "eager"
+			attention_backend = "default"
 
 		hf_attention = attention_backend
 		HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]
@@ -1530,13 +1530,14 @@ class Base(nn.Module):
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
 		# (AR) entropix sampling
+		# we do it after the penalizers because entropix's internal sampling doesn't account for them (but does do top_k/top_p/min_p)
 		if attentions is not None and quant_levels is None:
 			# move to CPU for speedups
 			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 
 			res = [ sample_entropix(
 				logit,
-				attentions[-1], #torch.stack(attentions, dim=1),
+				attentions[-1], # original code just uses the last attention scores
 				temperature,
 				top_k,
 				top_p,
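`sample_entropix` branches on the entropy and "varentropy" (variance of per-token surprisal) of both the logits and the attention scores it receives here. A minimal sketch of the two logit-side metrics, following the definitions in the linked entropix repo (which additionally converts to base-2; omitted here):

    import torch
    import torch.nn.functional as F

    def entropy_varentropy(logits: torch.Tensor):
        # logits: [..., vocab]; stay in log-space for numerical stability
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        ent = -(probs * log_probs).sum(dim=-1)  # H = -sum p * log p
        # varentropy: E[(-log p - H)^2] under p
        vent = (probs * (log_probs + ent.unsqueeze(-1)) ** 2).sum(dim=-1)
        return ent, vent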
@@ -407,7 +407,7 @@ def sample_entropix(
 	else:
 		metrics["action"] = 4
 
-	log_softmax = torch.nn.functional.log_softmax(logits)
+	log_softmax = torch.nn.functional.log_softmax(logits, dim=-1)
 	logits_uncertainty = ent + vent
 	attn_uncertainty = attn_ent + attn_vent
 
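The `dim=-1` addition is a genuine fix: calling `log_softmax` without `dim` on a multi-dimensional tensor is deprecated in PyTorch and picks a dimension implicitly, which can silently normalize over the wrong axis for batched logits. For example:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 5)
    log_probs = F.log_softmax(x, dim=-1)  # explicit: normalize over the vocab axis
    assert torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(2))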
@@ -167,6 +167,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
 	parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
 	parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
+	parser.add_argument("--entropix-sampling", action="store_true")
 	args, unknown = parser.parse_known_args()
 
 	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -176,6 +177,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		raise Exception("No reference audio provided.")
 	"""
 
+	if kwargs.pop("entropix-sampling", False):
+		args.entropix_sampling = True
+
 	tts = init_tts()
 
 	gr.Info("Inferencing...")
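The handler receives every component value in a flat `kwargs` dict keyed by component name; the numeric options pick those values up via `default=kwargs[...]`, but a `store_true` flag can't, hence the `kwargs.pop("entropix-sampling", False)` bridge that copies the checkbox state onto the parsed namespace after the fact (mirrored below for the STT handler).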
@@ -206,6 +210,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		dry_multiplier=args.dry_multiplier,
 		dry_base=args.dry_base,
 		dry_allowed_length=args.dry_allowed_length,
+		entropix_sampling=args.entropix_sampling
 	)
 
 	wav = wav.squeeze(0).cpu().numpy()
@@ -240,8 +245,10 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
 	parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
 	parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
+	parser.add_argument("--entropix-sampling", action="store_true")
 	args, unknown = parser.parse_known_args()
 
+
 	"""
 	if not args.references:
 		raise Exception("No reference audio provided.")
@@ -254,6 +261,9 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		duration = metadata.num_frames / metadata.sample_rate
 		args.max_ar_steps += duration
 	args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
 
+	if kwargs.pop("entropix-sampling", False):
+		args.entropix_sampling = True
+
 	tts = init_tts()
 
@@ -278,6 +288,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		dry_multiplier=args.dry_multiplier,
 		dry_base=args.dry_base,
 		dry_allowed_length=args.dry_allowed_length,
+		entropix_sampling=args.entropix_sampling,
 	)
 
 	return text
@@ -342,6 +353,7 @@ with ui:
 				with gr.Row():
 					#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
 					layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+					layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():
@@ -382,6 +394,7 @@ with ui:
 					layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 				with gr.Row():
 					layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+					layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 					layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():