added mirostat sampling (given a partially trained model, it got far more decent output than I expected; need to test on a better-trained model)
This commit is contained in:
parent 2567e082b5 · commit a6bfe43590
@ -140,13 +140,19 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
 * `--repetition-penalty-decay`: modifies the above factor applied to scale based on how far away it is in the past sequence.
 * `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky due to the AR already being well correlated with the length.
-* `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling. This is a very naive implementation that's effectively just greedy sampling across `B` spaces.
+* `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling.
+  + This is a very naive implementation that's effectively just greedy sampling across `B` spaces.
+* `--mirostat-tau`: (AR only) the "surprise value" when performing mirostat sampling.
+  + This simply uplifts the [original implementation](https://github.com/basusourya/mirostat/blob/master/mirostat.py) to perform it.
+  + **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime, at least).
+* `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise.
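For reference, a minimal sketch of driving these flags through the Python API rather than the CLI. The import path and file paths here are assumptions, not verbatim from the repo; the keyword names mirror the `tts.inference(...)` call further down in this commit:

```python
# a minimal sketch, assuming the TTS wrapper shown in this commit;
# config/reference/output paths are hypothetical placeholders
from vall_e.inference import TTS  # assumed module layout

tts = TTS( config="./data/config.yaml", device="cuda", amp=False )
tts.inference(
	text="Hello world.",
	references=["./reference.wav"],  # speaker prompt(s)
	out_path="./output.wav",
	ar_temp=0.95, nar_temp=0.25,
	mirostat_tau=3.0,  # tau > 0 enables mirostat in the AR pass
	mirostat_eta=0.1,  # keep beam_width at 0; the two are incompatible
)
```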

 ## To-Do

 * reduce load time for creating / preparing dataloaders (hint: remove use of `Path.glob` and `Path.rglob`).
 * train and release a ***good*** model.
 * extend to multiple languages (VALL-E X) and ~~extend to~~ train SpeechX features.
+  + This can easily be done by adding additional embeddings + tokens, rather than cramming them into the input prompt embedding.
 ## Notice

 - [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.

@ -29,13 +29,28 @@ def main():
 	parser.add_argument("--length-penalty", type=float, default=0.0)
 	parser.add_argument("--beam-width", type=int, default=0)
+	parser.add_argument("--mirostat-tau", type=float, default=0)
+	parser.add_argument("--mirostat-eta", type=float, default=0)

 	parser.add_argument("--device", type=str, default=None)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=None)
 	args = parser.parse_args()

 	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
-	tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, beam_width=args.beam_width )
+	tts.inference(
+		text=args.text,
+		references=args.references,
+		out_path=args.out_path,
+		input_prompt_length=args.input_prompt_length,
+		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
+		ar_temp=args.ar_temp, nar_temp=args.nar_temp,
+		top_p=args.top_p, top_k=args.top_k,
+		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+		length_penalty=args.length_penalty,
+		beam_width=args.beam_width,
+		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta
+	)

 if __name__ == "__main__":
 	main()

@ -291,7 +291,8 @@ class Dataset(_Dataset):

 		# shuffle it up a bit
 		prom_length = 0
-		trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
+		#trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+		trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)

 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
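A quick note on the magic number here, since it's easy to trip over: the commented-out "[3 seconds, 9 seconds]" range of `75 * 3` to `75 * 9` implies 75 codec frames per second, so the trim works in frames. A worked example under that assumption:

```python
# illustrative only: assumes 75 frames/sec, per the commented-out range above
prompt_duration = 3.0                    # cfg.dataset.prompt_duration, in seconds
trim_length = int(prompt_duration * 75)  # 225 frames
# the random.randint(-75, 75) jitter then yields 150..300 frames,
# i.e. a prompt anywhere from 2 to 4 seconds long
```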
@ -154,6 +154,8 @@ class TTS():
 		repetition_penalty_decay=0.0,
 		length_penalty=0.0,
 		beam_width=0,
+		mirostat_tau=0,
+		mirostat_eta=0.1,
 		out_path=None
 	):
 		if out_path is None:
@ -166,9 +168,24 @@ class TTS():
 		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

 		with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
-			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width)
+			resps_list = self.ar(
+				text_list=[phns], proms_list=[prom], max_steps=max_ar_steps,
+				sampling_temperature=ar_temp,
+				sampling_top_p=top_p, sampling_top_k=top_k,
+				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+				sampling_length_penalty=length_penalty,
+				sampling_beam_width=beam_width,
+				sampling_mirostat_tau=mirostat_tau,
+				sampling_mirostat_eta=mirostat_eta,
+			)
 			resps_list = [r.unsqueeze(-1) for r in resps_list]
-			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width)
+			resps_list = self.nar(
+				text_list=[phns], proms_list=[prom], resps_list=resps_list,
+				max_levels=max_nar_levels,
+				sampling_temperature=nar_temp,
+				sampling_top_p=top_p, sampling_top_k=top_k,
+				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+			)

 		wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)

@ -99,6 +99,9 @@ class AR(Base):
 		sampling_repetition_penalty_decay: float = 0.0,
 		sampling_length_penalty: float = 0.0,
 		sampling_beam_width: int = 0,
+
+		sampling_mirostat_tau: float = 0.0,
+		sampling_mirostat_eta: float = 0.1,
 	):
 		if resps_list is not None:
 			if self.interleave:
@ -120,7 +123,10 @@ class AR(Base):
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
 		stopped = torch.zeros(batch_size, device=device).bool()

-		state = {} if cfg.inference.recurrent_forward else None
+		recurrent_state = {} if cfg.inference.recurrent_forward else None
+		# one independent state dict per batch entry (mirostat mutates these in place);
+		# max_surprise starts at 2 * tau, per the reference implementation
+		mirostat = [
+			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_tau * 2, "error_surprise": 0, "running_total_surprise": 0}
+			for _ in range(batch_size)
+		] if sampling_mirostat_tau > 0.0 else None

 		sampling_beam_width_use_logs = True
 		scores = [ 1.0 ] * sampling_beam_width
@ -136,7 +142,7 @@ class AR(Base):
 				proms_list=proms_list,
 				resps_list=resps_list,

-				state=state
+				state=recurrent_state
 			)

 			r = super().sample(
@ -150,10 +156,17 @@ class AR(Base):
 				repetition_penalty_decay=sampling_repetition_penalty_decay,
 				length_penalty=sampling_length_penalty,
 				beam_width=sampling_beam_width,

+				mirostat=mirostat,
 			)

+			if mirostat is not None:
+				# r is the state
+				mirostat = r
+				# extract token from state
+				r = [ state["token"] for state in mirostat ]
 			# we do it here because the sampler will already expand our logits list
-			if sampling_beam_width > 0:
+			elif sampling_beam_width > 0:
 				# expand tuple
 				r, s = r
 				# first step, expand batch

@ -85,6 +85,8 @@ class AR_NAR(Base):
 		sampling_repetition_penalty_decay: float = 0.0,
 		sampling_length_penalty: float = 0.0,
 		sampling_beam_width: int = 0,
+		sampling_mirostat_tau: float = 0.0,
+		sampling_mirostat_eta: float = 0.1,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)
@ -140,6 +142,7 @@ class AR_NAR(Base):
 				repetition_penalty_decay=sampling_repetition_penalty_decay,
 				#length_penalty=sampling_length_penalty,
 				#beam_width=sampling_beam_width,
+				#mirostat=mirostat,
 			)

 			prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
@ -150,7 +153,10 @@ class AR_NAR(Base):
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
 		stopped = torch.zeros(batch_size, device=device).bool()

-		state = {} if cfg.inference.recurrent_forward else None
+		recurrent_state = {} if cfg.inference.recurrent_forward else None
+		# one independent state dict per batch entry, max_surprise seeded at 2 * tau (as above)
+		mirostat = [
+			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_tau * 2, "error_surprise": 0, "running_total_surprise": 0}
+			for _ in range(batch_size)
+		] if sampling_mirostat_tau > 0.0 else None

 		sampling_beam_width_use_logs = True
 		scores = [ 1.0 ] * sampling_beam_width
@ -166,7 +172,7 @@ class AR_NAR(Base):
 				proms_list=proms_list,
 				resps_list=resps_list,

-				state=state
+				state=recurrent_state
 			)

 			r = super().sample(
@ -180,10 +186,17 @@ class AR_NAR(Base):
 				repetition_penalty_decay=sampling_repetition_penalty_decay,
 				length_penalty=sampling_length_penalty,
 				beam_width=sampling_beam_width,

+				mirostat=mirostat,
 			)

+			if mirostat is not None:
+				# r is the state
+				mirostat = r
+				# extract token from state
+				r = [ state["token"] for state in mirostat ]
 			# we do it here because the sampler will already expand our logits list
-			if sampling_beam_width > 0:
+			elif sampling_beam_width > 0:
 				# expand tuple
 				r, s = r
 				# first step, expand batch

@ -136,6 +136,55 @@ def top_k_logits_list( logits_list, k ):
 		candidates[i] = tuple(t)
 	return candidates

+# Credit to: https://github.com/basusourya/mirostat/
+# performs mirostat-based sampling
+#  logits: tensor of raw token logits (only the last step's logits are used)
+#  state: the mirostat state dict for this sequence
+def mirostat_sample( logits, state = None ):
+	def compute_k(prob, n, tau):
+		num = 0
+		den = 0
+		for i in range(100):
+			b = prob[i]/prob[i+1]
+			t = (i+2)/(i+1)
+			num += math.log(b)*math.log(t)
+			den += math.log(t)**2
+
+		s = num/den
+		eps = s-1
+		k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s)
+		k = round(k)
+		return k
+
+	if "max_surprise" not in state:
+		state["max_surprise"] = state["tau"] * 2
+
+	if "error_surprise" not in state:
+		state["error_surprise"] = 0
+
+	if "running_total_surprise" not in state:
+		state["running_total_surprise"] = 0
+
+	sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
+	prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()
+
+	k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1
+
+	sorted_logits = sorted_logits[0:k]
+	sorted_indices = sorted_indices[0:k]
+	prob_topk = torch.softmax(sorted_logits, dim = 0)
+	prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
+
+	state["index_surprise"] = math.log2(1/prob_original[prev_i])
+	state["running_total_surprise"] += state["index_surprise"]
+	state["error_surprise"] = state["index_surprise"] - state["tau"]
+	state["max_surprise"] -= state["eta"] * state["error_surprise"]
+	state["token"] = sorted_indices[prev_i]
+
+	return state

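To make the uplifted math legible: `compute_k` least-squares-fits a Zipf exponent over the top sorted probabilities, then picks the truncation size `k` whose tail matches the current surprise target (`max_surprise`); after a token is drawn, that target is nudged toward `tau`. In the code's notation (my transcription as a reading aid, so defer to the linked reference implementation if they disagree):

```latex
\hat{s} = \frac{\sum_{i=1}^{100} \log\!\frac{p_i}{p_{i+1}} \,\log\!\frac{i+1}{i}}
               {\sum_{i=1}^{100} \left(\log\!\frac{i+1}{i}\right)^{2}},
\qquad \hat{\varepsilon} = \hat{s} - 1,
\qquad k = \left(\frac{\hat{\varepsilon}\, 2^{\mu}}{1 - n^{-\hat{\varepsilon}}}\right)^{1/\hat{s}}

% per-step update, where p_{\text{tok}} is the sampled token's probability
% and \mu is max_surprise, \tau the target, \eta the learning rate:
s_{\text{obs}} = \log_2 \frac{1}{p_{\text{tok}}}, \qquad
\mu \leftarrow \mu - \eta\,(s_{\text{obs}} - \tau)
```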
 # automagically parses a batch-list and returns it as a list
 class Embedding(nn.Embedding):
 	def forward(self, x_list: list[Tensor]) -> list[Tensor]:

@ -455,6 +504,8 @@ class Base(nn.Module):
 		length_penalty: float = 0.0,

 		beam_width: int = 0,

+		mirostat: list[dict] | None = None,
 	):
 		# (NAR) return the entire generated response
 		if quant_levels is not None:
@ -480,6 +531,12 @@ class Base(nn.Module):
 		if top_k > 0 or top_p < 1.0:
 			logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]

+		# do mirostat sampling
+		# currently incompatible with beam searching with the way the two are implemented; perhaps a night of brain bashing can make the two work
+		if mirostat is not None:
+			return [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]

 		# do beam search (naive implementation)
 		# picks the top-k across all batches, and re-batches those resultant tokens
 		# returns the logit scores as well to be P-concatted with the previous scores
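Since the state-dict protocol is easy to miss (the sampler returns updated states rather than tokens when mirostat is active), here's a self-contained sketch of calling `mirostat_sample` directly, outside the model loop. The import path is an assumption based on this diff's layout, and the logits are random stand-ins:

```python
import torch
from vall_e.models.base import mirostat_sample  # assumed location of the function above

# mirrors the per-sequence state the AR builds: n = vocab size, tau = target
# surprise, eta = learning rate, max_surprise = current target (starts at 2 * tau)
state = { "n": 1024, "tau": 3.0, "eta": 0.1, "max_surprise": 6.0, "error_surprise": 0, "running_total_surprise": 0 }

for step in range(8):
	logits = torch.randn(1, 1024)                 # stand-in for one step of model logits
	state = mirostat_sample(logits, state=state)  # mutates and returns the state
	token = state["token"]                        # the sampled token lives in the state
	print(step, token.item(), state["index_surprise"])
```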
@ -491,7 +548,6 @@ class Base(nn.Module):
 			return res, scores

 		# and sample
-		# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
 		return [ Categorical(logits=logit).sample() for logit in logits ]

 def example_usage():

@ -68,8 +68,9 @@ class NAR(Base):
 		sampling_top_p: float = 1.0,
 		sampling_repetition_penalty: float = 1.0,
 		sampling_repetition_penalty_decay: float = 0.0,
-		sampling_length_penalty: float = 0.0,
-		sampling_beam_width: int = 0,
+		sampling_length_penalty: float = 0.0, # unused
+		sampling_beam_width: int = 0, # unused
+		sampling_mirostat_tau: float = 0.0, # unused
 	):
 		"""
 		Args:
@ -140,6 +141,8 @@ class NAR(Base):
 				repetition_penalty_decay=sampling_repetition_penalty_decay,
 				#length_penalty=sampling_length_penalty,
 				#beam_width=sampling_beam_width,
+				#mirostat_tau=sampling_mirostat_tau,
+				#mirostat_state=mirostat_state,
 			)

 			prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]

@ -79,6 +79,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
 	parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
 	parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
+	parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
+	parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
 	args, unknown = parser.parse_known_args()

 	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@ -101,7 +103,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		top_k=args.top_k,
 		repetition_penalty=args.repetition_penalty,
 		repetition_penalty_decay=args.repetition_penalty_decay,
-		length_penalty=args.length_penalty
+		length_penalty=args.length_penalty,
+		mirostat_tau=args.mirostat_tau,
+		mirostat_eta=args.mirostat_eta,
 	)

 	wav = wav.squeeze(0).cpu().numpy()
@ -183,20 +187,23 @@ with ui:
 			with gr.Column(scale=7):
 				with gr.Row():
 					layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
-					layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+					layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=3, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
 					layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
 				with gr.Row():
 					layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
 					layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")

 				with gr.Row():
-					layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
+					layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
 					layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
 					layout["inference"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 				with gr.Row():
 					layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 					layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the repetition penalty based on how far back in time the token appeared in the sequence.")
 					layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
+				with gr.Row():
+					layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=5.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
+					layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")

 		layout["inference"]["buttons"]["inference"].click(
 			fn=do_inference,