very, very naive layerskip speculative sampling (it just checks if the current layer's state is good enough)
This commit is contained in:
parent
62fe5b0943
commit
ded746e157
@ -14,10 +14,10 @@ from torch.nn.utils.rnn import pad_sequence
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from time import perf_counter
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -66,6 +66,7 @@ class AR_NAR(Base):
|
|||||||
sampling_dry_base=1.75,
|
sampling_dry_base=1.75,
|
||||||
sampling_dry_allowed_length=2,
|
sampling_dry_allowed_length=2,
|
||||||
sampling_entropix=False,
|
sampling_entropix=False,
|
||||||
|
|
||||||
sampling_layer_skip: bool = False,
|
sampling_layer_skip: bool = False,
|
||||||
sampling_layer_skip_exit_layer: int = -1,
|
sampling_layer_skip_exit_layer: int = -1,
|
||||||
|
|
||||||
@ -281,6 +282,11 @@ class AR_NAR(Base):
|
|||||||
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
|
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
sampling_layer_skip_variables = {} if sampling_layer_skip else None
|
||||||
|
|
||||||
|
if sampling_layer_skip:
|
||||||
|
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer if sampling_layer_skip_exit_layer >= 0 else self.n_layers
|
||||||
|
|
||||||
for i, sequence in enumerate( sequence_list ):
|
for i, sequence in enumerate( sequence_list ):
|
||||||
# add <bos> to text for STT
|
# add <bos> to text for STT
|
||||||
if task_list[i] in text_task:
|
if task_list[i] in text_task:
|
||||||
@ -329,7 +335,7 @@ class AR_NAR(Base):
|
|||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
state=state,
|
state=state,
|
||||||
|
|
||||||
layer_skip_exit_layer=sampling_layer_skip_exit_layer,
|
layer_skip_variables=sampling_layer_skip_variables,
|
||||||
|
|
||||||
output_attentions=sampling_entropix,
|
output_attentions=sampling_entropix,
|
||||||
)
|
)
|
||||||
@ -360,15 +366,11 @@ class AR_NAR(Base):
|
|||||||
|
|
||||||
r = sampled[0]
|
r = sampled[0]
|
||||||
|
|
||||||
if sampled.entropy:
|
if cfg.experimental:
|
||||||
metrics.append( sampled.entropy )
|
if sampled.entropy:
|
||||||
"""
|
metrics.append( sampled.entropy )
|
||||||
elif sampled.confidence:
|
elif sampled.scores:
|
||||||
metrics.append( sampled.confidence )
|
metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
|
||||||
"""
|
|
||||||
elif False:
|
|
||||||
p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
|
|
||||||
metrics.append( p )
|
|
||||||
|
|
||||||
if mirostat is not None:
|
if mirostat is not None:
|
||||||
mirostat = sampled.scores
|
mirostat = sampled.scores
|
||||||
@ -402,7 +404,13 @@ class AR_NAR(Base):
|
|||||||
|
|
||||||
if metrics:
|
if metrics:
|
||||||
from ..plot import plot_sample_metrics
|
from ..plot import plot_sample_metrics
|
||||||
plot_sample_metrics( metrics )
|
filename = "metrics"
|
||||||
|
if sampling_entropix:
|
||||||
|
filename += f'[entropix]'
|
||||||
|
if sampling_layer_skip_exit_layer >= 0:
|
||||||
|
filename += f'[{sampling_layer_skip_exit_layer+1}]'
|
||||||
|
|
||||||
|
plot_sample_metrics( metrics, filename=f'{filename}.png' )
|
||||||
|
|
||||||
# pick the best scoring candidate
|
# pick the best scoring candidate
|
||||||
# desu this is always going to be candidate 0
|
# desu this is always going to be candidate 0
|
||||||
|
|||||||
@ -358,7 +358,8 @@ class LlamaModel_Adapted(LlamaModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
exit_layer: Optional[int] = -1,
|
|
||||||
|
layer_skip_lambda = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@ -451,7 +452,9 @@ class LlamaModel_Adapted(LlamaModel):
|
|||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
if 0 <= exit_layer and exit_layer <= l:
|
# check if we should early-exit
|
||||||
|
if layer_skip_lambda and layer_skip_lambda( l, hidden_states ):
|
||||||
|
#_logger.info(f"Early exit at layer: {l}")
|
||||||
break
|
break
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|||||||
@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding
|
|||||||
# yuck, kind of needed
|
# yuck, kind of needed
|
||||||
from ..data import get_task_symmap
|
from ..data import get_task_symmap
|
||||||
|
|
||||||
|
# these seem more elegant than a dict
|
||||||
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
|
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
|
||||||
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
|
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
|
||||||
LossStats = namedtuple('LossStats', ['loss', 'stats'])
|
LossStats = namedtuple('LossStats', ['loss', 'stats'])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -476,6 +477,7 @@ class Base(nn.Module):
|
|||||||
self.unified_position_ids = unified_position_ids
|
self.unified_position_ids = unified_position_ids
|
||||||
self.interleave = interleave
|
self.interleave = interleave
|
||||||
self.layerskip = layerskip
|
self.layerskip = layerskip
|
||||||
|
self.special_tasks = [ "len", "stt" ]
|
||||||
|
|
||||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||||
self.langs_emb = None
|
self.langs_emb = None
|
||||||
@ -827,7 +829,7 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
state = None,
|
state = None,
|
||||||
|
|
||||||
layer_skip_exit_layer = -1,
|
layer_skip_lambda = None,
|
||||||
|
|
||||||
output_attentions = False,
|
output_attentions = False,
|
||||||
output_hidden_states = False,
|
output_hidden_states = False,
|
||||||
@ -846,7 +848,7 @@ class Base(nn.Module):
|
|||||||
inputs_embeds=x,
|
inputs_embeds=x,
|
||||||
past_key_values=state,
|
past_key_values=state,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
use_cache=not self.training,
|
use_cache=False, # not self.training,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@ -855,8 +857,8 @@ class Base(nn.Module):
|
|||||||
if self.n_experts > 1 and self.training:
|
if self.n_experts > 1 and self.training:
|
||||||
kwargs["output_router_logits"] = True
|
kwargs["output_router_logits"] = True
|
||||||
|
|
||||||
if self.layerskip and 0 <= layer_skip_exit_layer and layer_skip_exit_layer < self.n_layers:
|
if self.layerskip and layer_skip_lambda is not None:
|
||||||
kwargs["exit_layer"] = layer_skip_exit_layer
|
kwargs["layer_skip_lambda"] = layer_skip_lambda
|
||||||
|
|
||||||
output = self.model(**kwargs)
|
output = self.model(**kwargs)
|
||||||
x = output["last_hidden_state"]
|
x = output["last_hidden_state"]
|
||||||
@ -938,14 +940,6 @@ class Base(nn.Module):
|
|||||||
# but skip the last state, as it already is normalized
|
# but skip the last state, as it already is normalized
|
||||||
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
|
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
|
||||||
|
|
||||||
# output projection layer with masking
|
|
||||||
if self.classifier is not None:
|
|
||||||
x = self.classifier(x) * mask
|
|
||||||
|
|
||||||
if output.hidden_states:
|
|
||||||
for i, state in enumerate( hidden_states ):
|
|
||||||
hidden_states[i] = self.classifier(hidden_states[i]) * m
|
|
||||||
|
|
||||||
return Logits(x, state, aux_loss, attentions, hidden_states)
|
return Logits(x, state, aux_loss, attentions, hidden_states)
|
||||||
|
|
||||||
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
|
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
|
||||||
@ -965,8 +959,6 @@ class Base(nn.Module):
|
|||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
batch_size = len(text_list)
|
batch_size = len(text_list)
|
||||||
|
|
||||||
special_tasks = ["stt", "len"]
|
|
||||||
|
|
||||||
inputs = [ [] for _ in range(batch_size) ]
|
inputs = [ [] for _ in range(batch_size) ]
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
quant_level = quant_levels[i] if quant_levels is not None else 0
|
quant_level = quant_levels[i] if quant_levels is not None else 0
|
||||||
@ -981,7 +973,7 @@ class Base(nn.Module):
|
|||||||
# Base-line TTS task
|
# Base-line TTS task
|
||||||
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
|
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
|
||||||
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
|
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
|
||||||
if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
|
if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
|
||||||
# insert the text prompt
|
# insert the text prompt
|
||||||
if text_list is not None and text_list[i] is not None:
|
if text_list is not None and text_list[i] is not None:
|
||||||
inputs[i].append( ( "text", text_list[i] ) )
|
inputs[i].append( ( "text", text_list[i] ) )
|
||||||
@ -1259,9 +1251,8 @@ class Base(nn.Module):
|
|||||||
stats = dict(acc = dict())
|
stats = dict(acc = dict())
|
||||||
|
|
||||||
device = logits[0].device
|
device = logits[0].device
|
||||||
special_tasks = [ "len", "stt" ]
|
|
||||||
summed_embeddings_task = [ "stt" ]
|
summed_embeddings_task = [ "stt" ]
|
||||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
|
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||||
|
|
||||||
# handles tasks where the prompt has task tokens injected in the middle
|
# handles tasks where the prompt has task tokens injected in the middle
|
||||||
def prompt_input_to_token( input, quant_level ):
|
def prompt_input_to_token( input, quant_level ):
|
||||||
@ -1443,11 +1434,46 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
quant_levels: int | list[int] | Tensor | None = None,
|
quant_levels: int | list[int] | Tensor | None = None,
|
||||||
state: dict | list | None = None,
|
state: dict | list | None = None,
|
||||||
layer_skip_exit_layer: int = -1,
|
|
||||||
|
layer_skip_variables: dict | None = None,
|
||||||
|
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
):
|
):
|
||||||
|
# return early if it's "good enough"
|
||||||
|
# lambda because we need to capture the classifier_quant_levels and mask
|
||||||
|
def layer_skip_lambda( layer, logits ):
|
||||||
|
kwargs = {
|
||||||
|
"logits_entropy": 0.1,
|
||||||
|
"logits_varentropy": 0.1,
|
||||||
|
"min_layer": self.n_layers // 2,
|
||||||
|
"max_layer": self.n_layers,
|
||||||
|
}
|
||||||
|
|
||||||
|
kwargs.update( layer_skip_variables )
|
||||||
|
|
||||||
|
# don't bother on early layers
|
||||||
|
if layer < kwargs["min_layer"]:
|
||||||
|
return False
|
||||||
|
# bail if we want to force early layers
|
||||||
|
if kwargs["max_layer"] < layer:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# hidden states aren't normalized
|
||||||
|
x = self.model.norm( logits )
|
||||||
|
|
||||||
|
# output projection layer with masking
|
||||||
|
if self.classifier is not None:
|
||||||
|
x = self.classifier(x) * m
|
||||||
|
elif self.classifiers is not None:
|
||||||
|
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
|
||||||
|
|
||||||
|
# calculate metrics
|
||||||
|
metrics = calculate_entropix_metrics( logits )
|
||||||
|
|
||||||
|
# exit early if "good enough"
|
||||||
|
return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
|
||||||
|
|
||||||
x_list = self.inputs_to_embeddings( inputs, quant_levels )
|
x_list = self.inputs_to_embeddings( inputs, quant_levels )
|
||||||
x, m = list_to_tensor(x_list)
|
x, m = list_to_tensor(x_list)
|
||||||
|
|
||||||
@ -1459,7 +1485,8 @@ class Base(nn.Module):
|
|||||||
if quant_levels is None:
|
if quant_levels is None:
|
||||||
quant_levels = [ 0 for _ in range(batch_size) ]
|
quant_levels = [ 0 for _ in range(batch_size) ]
|
||||||
|
|
||||||
if self.layerskip:
|
# we only need hidden states if we're training with layerskip
|
||||||
|
if self.layerskip and training:
|
||||||
output_hidden_states = True
|
output_hidden_states = True
|
||||||
|
|
||||||
# pad our input and mask, but retain the original length by doing it after
|
# pad our input and mask, but retain the original length by doing it after
|
||||||
@ -1478,6 +1505,8 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
# needs to be done here as we still have our raw inputs
|
# needs to be done here as we still have our raw inputs
|
||||||
position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
|
position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
|
||||||
|
|
||||||
|
classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||||
|
|
||||||
output = self._forward(
|
output = self._forward(
|
||||||
inputs=x,
|
inputs=x,
|
||||||
@ -1486,17 +1515,22 @@ class Base(nn.Module):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
output_attentions = output_attentions,
|
output_attentions = output_attentions,
|
||||||
output_hidden_states = output_hidden_states,
|
output_hidden_states = output_hidden_states,
|
||||||
layer_skip_exit_layer = layer_skip_exit_layer,
|
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
logits = output.logits
|
logits = output.logits
|
||||||
hidden_states = output.hidden_states
|
hidden_states = output.hidden_states
|
||||||
|
|
||||||
|
# output projection layer with masking
|
||||||
|
if self.classifier is not None:
|
||||||
|
logits = self.classifier(logits) * m
|
||||||
|
|
||||||
|
if output.hidden_states:
|
||||||
|
for i, state in enumerate( hidden_states ):
|
||||||
|
hidden_states[i] = self.classifier(hidden_states[i]) * m
|
||||||
# to-do: piece-wise classification, now that there's a head for text
|
# to-do: piece-wise classification, now that there's a head for text
|
||||||
# although again, one single monolithic head would be preferable instead......
|
# although again, one single monolithic head would be preferable instead......
|
||||||
if self.classifiers is not None:
|
if self.classifiers is not None:
|
||||||
special_tasks = [ "len", "stt" ]
|
|
||||||
classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
|
|
||||||
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
|
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
|
||||||
|
|
||||||
if hidden_states is not None:
|
if hidden_states is not None:
|
||||||
@ -1508,7 +1542,6 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
if hidden_states is not None:
|
if hidden_states is not None:
|
||||||
for i, state in enumerate( hidden_states ):
|
for i, state in enumerate( hidden_states ):
|
||||||
# remove padding
|
|
||||||
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
|
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
|
||||||
|
|
||||||
# compute loss if the target is given
|
# compute loss if the target is given
|
||||||
@ -1573,6 +1606,8 @@ class Base(nn.Module):
|
|||||||
# other
|
# other
|
||||||
attentions=None,
|
attentions=None,
|
||||||
):
|
):
|
||||||
|
batch_size = len( logits )
|
||||||
|
|
||||||
if min_temperature < 0:
|
if min_temperature < 0:
|
||||||
min_temperature = temperature
|
min_temperature = temperature
|
||||||
|
|
||||||
@ -1598,6 +1633,14 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
if res:
|
if res:
|
||||||
return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
|
return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
|
||||||
|
"""
|
||||||
|
elif quant_levels is None:
|
||||||
|
seq_lens = [ logit.shape[0] for logit in logits ]
|
||||||
|
entropy = [ calculate_entropix_metrics(
|
||||||
|
logit[:seq_lens[batch], :], # ( seq_len, vocab )
|
||||||
|
#attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
|
||||||
|
) for batch, logit in enumerate(logits) ]
|
||||||
|
"""
|
||||||
|
|
||||||
# (NAR) return the entire generated response
|
# (NAR) return the entire generated response
|
||||||
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
|
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
|
||||||
@ -1666,4 +1709,10 @@ class Base(nn.Module):
|
|||||||
else:
|
else:
|
||||||
res = [ Categorical(logits=logit).sample() for logit in logits ]
|
res = [ Categorical(logits=logit).sample() for logit in logits ]
|
||||||
|
|
||||||
|
# calculate token probabilities
|
||||||
|
scores = [
|
||||||
|
[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
|
||||||
|
for logit, tokens in zip(logits, res)
|
||||||
|
]
|
||||||
|
|
||||||
return Sampled(res, scores, entropy)
|
return Sampled(res, scores, entropy)
|
||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -93,7 +94,7 @@ def plot(paths, args):
|
|||||||
#bbox_to_anchor=(1.04, 0.5),
|
#bbox_to_anchor=(1.04, 0.5),
|
||||||
)
|
)
|
||||||
|
|
||||||
def plot_sample_metrics( metrics ):
|
def plot_sample_metrics( metrics, filename=None ):
|
||||||
"""
|
"""
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
|
fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
|
||||||
@ -111,7 +112,11 @@ def plot_sample_metrics( metrics ):
|
|||||||
#bbox_to_anchor=(1.04, 0.5),
|
#bbox_to_anchor=(1.04, 0.5),
|
||||||
)
|
)
|
||||||
|
|
||||||
out_path = cfg.rel_path / "metrics.png"
|
if not filename:
|
||||||
|
filename = f'{time.time()}.png'
|
||||||
|
|
||||||
|
out_path = cfg.rel_path / "metrics" / filename
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
plt.savefig(out_path, bbox_inches="tight")
|
plt.savefig(out_path, bbox_inches="tight")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -426,6 +426,7 @@ def sample_entropix(
|
|||||||
top_p=1.0,
|
top_p=1.0,
|
||||||
min_p=0.0,
|
min_p=0.0,
|
||||||
cfg=EntropixSamplerConfig(),
|
cfg=EntropixSamplerConfig(),
|
||||||
|
metrics_only=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
temperature = cfg.temp
|
temperature = cfg.temp
|
||||||
|
|||||||
@ -193,7 +193,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||||||
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
|
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
|
||||||
parser.add_argument("--entropix-sampling", action="store_true")
|
parser.add_argument("--entropix-sampling", action="store_true")
|
||||||
parser.add_argument("--layer-skip", action="store_true")
|
parser.add_argument("--layer-skip", action="store_true")
|
||||||
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"])
|
parser.add_argument("--layer-skip-exit-layer", type=int, default=kwargs["layer-skip-exit-layer"] if cfg.experimental else -1)
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
||||||
@ -384,7 +384,7 @@ with ui:
|
|||||||
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
|
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.5, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
|
||||||
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
layout["inference_tts"]["inputs"]["layer-skip"] = gr.Checkbox(label="Layer Skip", info="Performs self-speculative early exit 'sampling'")
|
||||||
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
|
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
|
||||||
with gr.Tab("Sampler Settings"):
|
with gr.Tab("Sampler Settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -411,9 +411,10 @@ with ui:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
|
layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||||
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Model layer to exit early from.")
|
layout["inference_tts"]["inputs"]["layer-skip-exit-layer"] = gr.Slider(value=11, minimum=0, maximum=11, step=1, label="Layer Skip Exit Layer", info="Maximum model layer to exit early from.")
|
||||||
|
|
||||||
|
|
||||||
layout["inference_tts"]["buttons"]["inference"].click(
|
layout["inference_tts"]["buttons"]["inference"].click(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user