implemented a naive beam search (I really should be taking a break)
This commit is contained in:
parent a6ae344e5b
commit 23a5fdd645
@@ -139,8 +139,8 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
 * `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
 * `--repetition-penalty-decay`: modifies the above factor applied to scale based on how far away it is in the past sequence.
-* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky.
+* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky due to the AR already being well correlated with the length.
+* `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling. This is a very naive implementation that's effectively just greedy sampling across `B` spaces.

 ## To-Do

 * reduce load time for creating / preparing dataloaders (hint: remove use of `Path.glob` and `Path.rglob`).

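Note: to make the `--beam-width` description above concrete, here is a minimal sketch of what "greedy sampling across `B` spaces" amounts to. The function name and shapes here are illustrative only; the real logic lands in `Base.sample()` further down this diff.

```python
import torch

# Illustrative sketch of one step of the naive beam search described above.
def naive_beam_step(logits: torch.Tensor, beam_width: int):
    # logits: (B, vocab), one row per live beam
    flat = logits.flatten()                       # (B * vocab,)
    picks = torch.topk(flat, beam_width).indices  # best tokens across ALL beams
    beam_ids = picks // logits.size(-1)           # which beam each pick extends
    token_ids = picks % logits.size(-1)           # which token gets appended
    return beam_ids, token_ids
```

Because no per-beam cumulative score carries over between steps, each step is effectively an independent greedy pick over `B` copies of the batch, which is why the README hedges on it being naive.
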
@@ -27,6 +27,7 @@ def main():
     parser.add_argument("--repetition-penalty", type=float, default=1.0)
     parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
     parser.add_argument("--length-penalty", type=float, default=0.0)
+    parser.add_argument("--beam-width", type=int, default=0)

     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--amp", action="store_true")
@@ -34,7 +35,7 @@ def main():
     args = parser.parse_args()

     tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
-    tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )
+    tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, beam_width=args.beam_width )

 if __name__ == "__main__":
     main()

@@ -139,7 +139,23 @@ class TTS():
         return res

     @torch.inference_mode()
-    def inference( self, text, references, max_ar_steps=6 * 75, max_nar_levels=7, input_prompt_length=0.0, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ):
+    def inference(
+        self,
+        text,
+        references,
+        max_ar_steps=6 * 75,
+        max_nar_levels=7,
+        input_prompt_length=0.0,
+        ar_temp=0.95,
+        nar_temp=0.5,
+        top_p=1.0,
+        top_k=0,
+        repetition_penalty=1.0,
+        repetition_penalty_decay=0.0,
+        length_penalty=0.0,
+        beam_width=0,
+        out_path=None
+    ):
         if out_path is None:
             out_path = f"./data/{cfg.start_time}.wav"
@@ -150,9 +166,9 @@ class TTS():
         phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

         with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
-            resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
+            resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width)
             resps_list = [r.unsqueeze(-1) for r in resps_list]
-            resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
+            resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width)

         wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)

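Note: for orientation, a hypothetical call into the updated `inference()` signature. The import path and the file paths are assumptions, not taken from this diff.

```python
# Hypothetical usage of the new beam_width parameter (paths are placeholders).
from vall_e.inference import TTS  # assumed module path

tts = TTS( config="./config.yaml", ar_ckpt=None, nar_ckpt=None, device="cuda", dtype="float32", amp=False )
tts.inference(
    text="Hello world.",
    references="./reference.wav",
    out_path="./out.wav",
    beam_width=4,  # 0 keeps the old single-path sampling
)
```
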
@@ -91,12 +91,14 @@ class AR(Base):
         proms_list: list[Tensor],
         resps_list: list[Tensor] | None = None,
         max_steps: int = 1000,
+
         sampling_temperature: float = 1.0,
         sampling_top_k: int = -100,
         sampling_top_p: float = 1.0,
         sampling_repetition_penalty: float = 1.0,
         sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
+        sampling_beam_width: int = 0,
     ):
         if resps_list is not None:
             if self.interleave:

@@ -126,24 +128,39 @@ class AR(Base):
         for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
             # get next in sequence

-            r = super().forward(
+            logits = super().forward(
                 text_list=text_list,
                 proms_list=proms_list,
                 resps_list=self._unsqueeze_list(resps_list),
                 quant_levels=None,
-                sampling_temperature=sampling_temperature,
-                sampling_top_p=sampling_top_p,
-                sampling_top_k=sampling_top_k,
-                sampling_repetition_penalty=sampling_repetition_penalty,
-                sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
-                sampling_length_penalty=sampling_length_penalty,
                 state=state
             )
+
+            r = super().sample(
+                logits=logits,
+                resps_list=resps_list,
+
+                temperature=sampling_temperature,
+                top_p=sampling_top_p,
+                top_k=sampling_top_k,
+                repetition_penalty=sampling_repetition_penalty,
+                repetition_penalty_decay=sampling_repetition_penalty_decay,
+                length_penalty=sampling_length_penalty,
+                beam_width=sampling_beam_width,
+            )
+
+            # first step, expand batch
+            # we do it here because the sampler will already expand our logits list
+            if sampling_beam_width > 0 and batch_size == 1:
+                text_list = text_list * sampling_beam_width
+                proms_list = proms_list * sampling_beam_width
+                resps_list = resps_list * sampling_beam_width

             # append tokens
             for i, ri in enumerate(r):
                 if self.stop_token in ri:
                     stopped[i] = True

                 resps_list[i] = torch.cat([resps_list[i], ri])

             # stop token found

@@ -83,6 +83,7 @@ class AR_NAR(Base):
         sampling_repetition_penalty: float = 1.0,
         sampling_repetition_penalty_decay: float = 0.0,
         sampling_length_penalty: float = 0.0,
+        sampling_beam_width: int = 0,
     ):
         device = text_list[0].device
         batch_size = len(text_list)

@@ -119,28 +120,33 @@ class AR_NAR(Base):

             quant_levels = torch.full((len(text_list),), level, device=device)

-            resps_list = super().forward(
-                text_list,
-                proms_list,
-                prev_list,
+            logits = super().forward(
+                text_list=text_list,
+                proms_list=proms_list,
+                resps_list=prev_list,
                 quant_levels=quant_levels,
-                sampling_temperature=sampling_temperature,
-                sampling_top_p=sampling_top_p,
-                sampling_top_k=sampling_top_k,
-                sampling_repetition_penalty=sampling_repetition_penalty,
-                sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
-                sampling_length_penalty=sampling_length_penalty,
             )

-            prev_list = [
-                torch.cat([rs, r.unsqueeze(-1)], dim=-1)
-                for rs, r in zip(prev_list, resps_list)
-            ]
+            resps_list = super().sample(
+                logits=logits,
+                resps_list=prev_list,
+                quant_levels=quant_levels,
+
+                temperature=sampling_temperature,
+                top_p=sampling_top_p,
+                top_k=sampling_top_k,
+                repetition_penalty=sampling_repetition_penalty,
+                repetition_penalty_decay=sampling_repetition_penalty_decay,
+                #length_penalty=sampling_length_penalty,
+                #beam_width=sampling_beam_width,
+            )
+
+            prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]

         return prev_list

         # is AR
-        resps_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
+        sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
         stopped = torch.zeros(batch_size, device=device).bool()

         state = {} if cfg.inference.recurrent_forward else None

@@ -151,31 +157,53 @@ class AR_NAR(Base):
         for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
             # get next in sequence

-            r = super().forward(
-                text_list,
-                proms_list,
-                self._unsqueeze_list(resps_list),
-                sampling_temperature=sampling_temperature,
-                sampling_top_p=sampling_top_p,
-                sampling_top_k=sampling_top_k,
-                sampling_repetition_penalty=sampling_repetition_penalty,
-                sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
-                sampling_length_penalty=sampling_length_penalty,
+            resps_list = self._unsqueeze_list(sequence_list)
+            logits = super().forward(
+                text_list=text_list,
+                proms_list=proms_list,
+                resps_list=resps_list,
+
                 state=state
             )

+            r = super().sample(
+                logits=logits,
+                resps_list=resps_list,
+
+                temperature=sampling_temperature,
+                top_p=sampling_top_p,
+                top_k=sampling_top_k,
+                repetition_penalty=sampling_repetition_penalty,
+                repetition_penalty_decay=sampling_repetition_penalty_decay,
+                length_penalty=sampling_length_penalty,
+                beam_width=sampling_beam_width,
+            )
+
+            # first step, expand batch
+            # we do it here because the sampler will already expand our logits list
+            if sampling_beam_width > 0 and batch_size == 1:
+                batch_size *= sampling_beam_width
+                text_list = text_list * sampling_beam_width
+                proms_list = proms_list * sampling_beam_width
+                sequence_list = sequence_list * sampling_beam_width
+                stopped = torch.zeros(batch_size, device=device).bool()

             # append tokens
             for i, ri in enumerate(r):
                 if self.stop_token in ri:
                     stopped[i] = True
-                resps_list[i] = torch.cat([resps_list[i], ri])
+                sequence_list[i] = torch.cat([sequence_list[i], ri])

             # stop token found
             stopped |= r == self.stop_token
             if stopped.all().item():
                 break

-        return [self._prune(r) for r in resps_list]
+        # pick the first candidate
+        if sampling_beam_width:
+            sequence_list = sequence_list[:1]
+
+        return [self._prune(r) for r in sequence_list]


 def example_usage():

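Note: the "expand batch" step in the hunk above is worth spelling out. After the first sampling call returns `sampling_beam_width` candidates, the driving lists are widened by plain Python list repetition, so each beam becomes an ordinary batch entry. A minimal sketch of that mechanic, with made-up values:

```python
# Illustrative only: how list repetition widens the batch for beams.
beam_width = 4
text_list = ["<phoneme tensor>"]        # batch of 1 before the first AR step
text_list = text_list * beam_width      # four references to the same prompt
assert len(text_list) == 4
# From here on, each beam is just another batch entry; sample() returns one
# token per entry, and the beams diverge as different tokens get appended.
```

Since `torch.topk` returns indices sorted by descending value, the first entry of `sequence_list` at the end of the loop is the beam whose final token scored highest, which is what the closing `sequence_list[:1]` relies on.
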
@@ -200,11 +228,9 @@ def example_usage():
     qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)

     text_list = [
-        #torch.tensor([1, 2, 3], device=device),
         tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
     ]
     proms_list = [
-        #x8(torch.tensor([1, 2, 3], device=device)),
         qnt[:75*3, :].to(device),
     ]
     resps_list = [

@@ -232,7 +258,7 @@ def example_usage():
     model = AR_NAR(**kwargs).to(device)
     #steps = 500
     #optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-    steps = 500
+    steps = 1000
     optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
     engine = Engine(model=model, optimizer=optimizer)

@@ -241,7 +267,7 @@ def example_usage():
     @torch.inference_mode()
     def sample( name, steps=600 ):
         engine.eval()
-        resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+        resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95, sampling_beam_width=16 )

         for i, o in enumerate(resps_list):
             _ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)

@@ -2,6 +2,7 @@ import math
 import torch
 import torch.nn.functional as F
 import traceback
+import numpy as np

 from typing import Literal, overload
 from functools import partial

@@ -53,7 +54,7 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is
 def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
-    if factor == 1.0:
+    if factor == 1.0 or previous is None:
         return logits

     unique = set()

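Note: the two comments above are terse, so here is a self-contained sketch of the penalty they describe. The exact scaling formula is an assumption for illustration; only the `one_time` and `decay` behaviors are taken from the comments.

```python
import torch

# Sketch of a distance-decayed repetition penalty (assumed formula).
def repetition_penalize_sketch(logits: torch.Tensor, previous: torch.Tensor,
                               factor: float = 1.2, decay: float = 0.0,
                               one_time: bool = True):
    # logits: (n, vocab); previous: 1-D tensor of already-emitted token ids
    seen = set()
    # walk the history newest-to-oldest so `distance` grows into the past
    for distance, token in enumerate(reversed(previous.tolist()), start=1):
        if one_time and token in seen:
            continue  # `one_time`: penalize each distinct token only once
        # scale the penalty by how far back the token appeared;
        # decay == 0.0 reduces this to a flat penalty of `factor`
        logits[:, token] /= factor * (distance ** decay)
        seen.add(token)
    return logits
```
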
@@ -115,6 +116,7 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
         # scatter sorted tensors to original indexing
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         logits[indices_to_remove] = filter_value

     return logits
+

@@ -341,13 +343,6 @@ class Base(nn.Module):
         targ_list: list[Tensor] | None = None,

         quant_levels: Tensor | None = None,
-        sampling_temperature: float = 1.0,
-        sampling_top_k: int = -100,
-        sampling_top_p: float = 1.0,
-        sampling_repetition_penalty: float = 1.0,
-        sampling_repetition_penalty_decay: float = 0.0,
-        sampling_length_penalty: float = 0.0,
-
         state: dict | None = None,
     ):
         x_list = self._samplewise_merge_tensors(

@@ -428,8 +423,25 @@ class Base(nn.Module):
             precision = self.precision_metric( inputs, target ),
         )

-        return logits
+        return logits
+
+    def sample(
+        self,
+        logits: list[Tensor],
+        resps_list: list[Tensor],
+        quant_levels: Tensor | None = None,
+
+        temperature: float = 1.0,
+        top_k: int = -100,
+        top_p: float = 1.0,
+
+        repetition_penalty: float = 1.0,
+        repetition_penalty_decay: float = 0.0,
+
+        length_penalty: float = 0.0,
+
+        beam_width: int = 0,
+    ):
         # (NAR) return the entire generated response
         if quant_levels is not None:
             logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]

@@ -441,19 +453,37 @@ class Base(nn.Module):
         logits = [ logit[-1:] for logit in logits ]

         # perform repetition penalizing
-        logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty, decay=sampling_repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+        logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]

         # (AR) perform length penalizing
         if quant_levels is None and self.causal:
-            logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
+            logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]

         # scale our logits by the temp
-        logits = [ logit / sampling_temperature for logit in logits ]
+        logits = [ logit / temperature for logit in logits ]

         # perform top_k/top_p filtering of our logits
-        if sampling_top_k > 0:
-            logits = [ top_k_top_p_filtering(logit, top_k=sampling_top_k, top_p=sampling_top_p) for logit in logits ]
+        if top_k > 0 or top_p < 1.0:
+            logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
+
+        # do beam search (naive implementation)
+        # picks the top-k across all batches, and re-batches those resultant tokens
+        # this doesn't do any other mumbo with previous logits
+        # to-do: not naively implement beam searching
+        if beam_width > 1:
+            # ( batch, tokens ) => ( batch x tokens )
+            flattened = torch.cat( logits )
+            candidates = list(torch.topk(flattened.flatten(), beam_width).indices.tolist()) # perform top-k across all logits
+            for i, index in enumerate(candidates):
+                t = []
+                N = np.prod(flattened.size())
+                for n in flattened.size():
+                    N //= n
+                    t.append(index // N)
+                    index %= N
+                candidates[i] = tuple(t)
+            return [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] #, [ logits[batch] for batch, token in candidates ]

         # and sample
         # the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
         return [ Categorical(logits=logit).sample() for logit in logits ]

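Note: the index-unraveling loop in the beam branch above is a hand-rolled flat-to-multidimensional index conversion. A quick illustrative check of the equivalence, with hypothetical shapes that are not part of the commit:

```python
import numpy as np
import torch

# Hypothetical: 4 beams over a 1025-token vocabulary (1024 codes + stop token).
flattened = torch.randn(4, 1025)
flat_idx = torch.topk(flattened.flatten(), k=4).indices.tolist()

# The loop in sample() computes, for each flat index, the same (batch, token)
# pair that numpy's unravel_index would:
pairs = [np.unravel_index(i, flattened.shape) for i in flat_idx]
for index, expected in zip(flat_idx, pairs):
    t, N = [], np.prod(flattened.shape)
    for n in flattened.shape:
        N //= n
        t.append(index // N)
        index %= N
    assert tuple(t) == tuple(int(e) for e in expected)
```

Because only the top `beam_width` last-step logits survive and no running sequence score is kept, this really is the "greedy sampling across `B` spaces" the README warns about; a non-naive version would carry a cumulative log-probability per beam, which is the in-code to-do.
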
@@ -98,11 +98,11 @@ class NAR(Base):

         quant_levels = quant_levels.to(device=device)

-        _ = super().forward(
-            text_list,
-            proms_list,
-            prev_list,
-            targ_list,
+        logits = super().forward(
+            text_list=text_list,
+            proms_list=proms_list,
+            resps_list=prev_list,
+            targ_list=targ_list,
             quant_levels=quant_levels,
         )

@@ -120,23 +120,28 @@ class NAR(Base):

             quant_levels = torch.full((len(text_list),), level, device=device)

-            resps_list = super().forward(
-                text_list,
-                proms_list,
-                prev_list,
+            logits = super().forward(
+                text_list=text_list,
+                proms_list=proms_list,
+                resps_list=prev_list,
                 quant_levels=quant_levels,
-                sampling_temperature=sampling_temperature,
-                sampling_top_p=sampling_top_p,
-                sampling_top_k=sampling_top_k,
-                sampling_repetition_penalty=sampling_repetition_penalty,
-                sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
-                sampling_length_penalty=sampling_length_penalty,
             )

-            prev_list = [
-                torch.cat([rs, r.unsqueeze(-1)], dim=-1)
-                for rs, r in zip(prev_list, resps_list)
-            ]
+            resps_list = super().sample(
+                logits=logits,
+                resps_list=resps_list,
+                quant_levels=quant_levels,
+
+                temperature=sampling_temperature,
+                top_p=sampling_top_p,
+                top_k=sampling_top_k,
+                repetition_penalty=sampling_repetition_penalty,
+                repetition_penalty_decay=sampling_repetition_penalty_decay,
+                #length_penalty=sampling_length_penalty,
+                #beam_width=sampling_beam_width,
+            )
+
+            prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]

         return prev_list

@@ -46,10 +46,6 @@ def train_feeder(engine, batch):

 @torch.inference_mode()
 def run_eval(engines, disabled_engines, eval_name, dl):
-    engines_stats = {
-        'eval': eval_name
-    }
-
     AR = None
     NAR = None
     AR_NAR = None

@@ -156,7 +152,7 @@ def run_eval(engines, disabled_engines, eval_name, dl):

     stats = {k: sum(v) / len(v) for k, v in stats.items()}
-    engines_stats.update(flatten_dict({ name: stats }))
+    engines_stats.update({ f'{name}.{eval_name}': stats })

     iteration = engines.global_step
     engines_stats['it'] = iteration

@@ -65,6 +65,7 @@ def init_tts(restart=False):
 @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
 def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser = argparse.ArgumentParser(allow_abbrev=False)
+    # I'm very sure I can procedurally generate this list
     parser.add_argument("--text", type=str, default=kwargs["text"])
     parser.add_argument("--references", type=str, default=kwargs["reference"])
     parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])

@@ -77,6 +78,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
     parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
     parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
+    parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
     args, unknown = parser.parse_known_args()

     tmp = tempfile.NamedTemporaryFile(suffix='.wav')

@@ -181,7 +183,7 @@ with ui:
         with gr.Column(scale=7):
             with gr.Row():
                 layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
-                layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=3, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+                layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
                 layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
             with gr.Row():
                 layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")

@@ -190,6 +192,7 @@ with ui:
             with gr.Row():
                 layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
                 layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
+                layout["inference"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
             with gr.Row():
                 layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
                 layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the repetition penalty based on how far back in time the token appeared in the sequence.")