added mirostat sampling (given a partially trained model, it produced far more decent output than I expected; need to test on a better trained model)
This commit is contained in:
parent
2567e082b5
commit
a6bfe43590
10
README.md
10
README.md
|
@ -140,13 +140,19 @@ And some experimental sampling flags you can use too (your mileage will ***defin
|
|||
* `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
|
||||
* `--repetition-penalty-decay`: modifies the above factor applied to scale based on how far away it is in the past sequence.
|
||||
* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky due to the AR already being well correlated with the length.
|
||||
* `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling. This is a very naive implementation that's effectively just greedy sampling across `B` spaces.
|
||||
* `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling.
|
||||
+ This is a very naive implementation that's effectively just greedy sampling across `B` spaces.
|
||||
* `--mirostat-tau`: (AR only) the "surprise value" when performing mirostat sampling.
|
||||
+ This simply reuses the [original implementation](https://github.com/basusourya/mirostat/blob/master/mirostat.py) to perform it.
|
||||
+ **!**NOTE**!**: This is incompatible with beam search sampling (for the time being at least).
|
||||
* `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise.
|
||||
|
||||
## To-Do
|
||||
|
||||
* reduce load time for creating / preparing dataloaders (hint: remove use of `Path.glob` and `Path.rglob`).
|
||||
* train and release a ***good*** model.
|
||||
* extend to multiple languages (VALL-E X) and ~~extend to~~ train SpeechX features.
|
||||
|
||||
+ This can easily be done by adding additional embeddings + tokens, rather than cramming into the input prompt embedding.
|
||||
## Notice
|
||||
|
||||
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
|
||||
|
|
|
@ -28,6 +28,9 @@ def main():
|
|||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--beam-width", type=int, default=0)
|
||||
|
||||
parser.add_argument("--mirostat-tau", type=float, default=0)
|
||||
parser.add_argument("--mirostat-eta", type=float, default=0)
|
||||
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
|
@ -35,7 +38,19 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||
tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, beam_width=args.beam_width )
|
||||
tts.inference(
|
||||
text=args.text,
|
||||
references=args.references,
|
||||
out_path=args.out_path,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
|
||||
top_p=args.top_p, top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty,
|
||||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -291,7 +291,8 @@ class Dataset(_Dataset):
|
|||
|
||||
# shuffle it up a bit
|
||||
prom_length = 0
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
|
||||
#trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
|
||||
|
||||
for _ in range(cfg.dataset.max_prompts):
|
||||
path = random.choice(choices)
|
||||
|
|
|
@ -154,6 +154,8 @@ class TTS():
|
|||
repetition_penalty_decay=0.0,
|
||||
length_penalty=0.0,
|
||||
beam_width=0,
|
||||
mirostat_tau=0,
|
||||
mirostat_eta=0.1,
|
||||
out_path=None
|
||||
):
|
||||
if out_path is None:
|
||||
|
@ -166,9 +168,24 @@ class TTS():
|
|||
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
||||
|
||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width)
|
||||
resps_list = self.ar(
|
||||
text_list=[phns], proms_list=[prom], max_steps=max_ar_steps,
|
||||
sampling_temperature=ar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
sampling_length_penalty=length_penalty,
|
||||
sampling_beam_width=beam_width,
|
||||
sampling_mirostat_tau=mirostat_tau,
|
||||
sampling_mirostat_eta=mirostat_eta,
|
||||
)
|
||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width)
|
||||
resps_list = self.nar(
|
||||
text_list=[phns], proms_list=[prom], resps_list=resps_list,
|
||||
max_levels=max_nar_levels,
|
||||
sampling_temperature=nar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
)
|
||||
|
||||
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
|
||||
|
||||
|
|
|
@ -99,6 +99,9 @@ class AR(Base):
|
|||
sampling_repetition_penalty_decay: float = 0.0,
|
||||
sampling_length_penalty: float = 0.0,
|
||||
sampling_beam_width: int = 0,
|
||||
|
||||
sampling_mirostat_tau: float = 0.0,
|
||||
sampling_mirostat_eta: float = 0.1,
|
||||
):
|
||||
if resps_list is not None:
|
||||
if self.interleave:
|
||||
|
@ -120,7 +123,10 @@ class AR(Base):
|
|||
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
state = {} if cfg.inference.recurrent_forward else None
|
||||
recurrent_state = {} if cfg.inference.recurrent_forward else None
|
||||
mirostat = [
|
||||
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
|
||||
] * batch_size if sampling_mirostat_tau > 0.0 else None
|
||||
|
||||
sampling_beam_width_use_logs = True
|
||||
scores = [ 1.0 ] * sampling_beam_width
|
||||
|
@ -136,7 +142,7 @@ class AR(Base):
|
|||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
|
||||
state=state
|
||||
state=recurrent_state
|
||||
)
|
||||
|
||||
r = super().sample(
|
||||
|
@ -150,10 +156,17 @@ class AR(Base):
|
|||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
length_penalty=sampling_length_penalty,
|
||||
beam_width=sampling_beam_width,
|
||||
|
||||
mirostat=mirostat,
|
||||
)
|
||||
|
||||
if mirostat is not None:
|
||||
# r is the state
|
||||
mirostat = r
|
||||
# extract token from state
|
||||
r = [ state["token"] for state in mirostat ]
|
||||
# we do it here because the sampler will already expand our logits list
|
||||
if sampling_beam_width > 0:
|
||||
elif sampling_beam_width > 0:
|
||||
# expand tuple
|
||||
r, s = r
|
||||
# first step, expand batch
|
||||
|
|
|
@ -85,6 +85,8 @@ class AR_NAR(Base):
|
|||
sampling_repetition_penalty_decay: float = 0.0,
|
||||
sampling_length_penalty: float = 0.0,
|
||||
sampling_beam_width: int = 0,
|
||||
sampling_mirostat_tau: float = 0.0,
|
||||
sampling_mirostat_eta: float = 0.1,
|
||||
):
|
||||
device = text_list[0].device
|
||||
batch_size = len(text_list)
|
||||
|
@ -140,6 +142,7 @@ class AR_NAR(Base):
|
|||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat=mirostat,
|
||||
)
|
||||
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
@ -150,7 +153,10 @@ class AR_NAR(Base):
|
|||
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
state = {} if cfg.inference.recurrent_forward else None
|
||||
recurrent_state = {} if cfg.inference.recurrent_forward else None
|
||||
mirostat = [
|
||||
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
|
||||
] * batch_size if sampling_mirostat_tau > 0.0 else None
|
||||
|
||||
sampling_beam_width_use_logs = True
|
||||
scores = [ 1.0 ] * sampling_beam_width
|
||||
|
@ -166,7 +172,7 @@ class AR_NAR(Base):
|
|||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
|
||||
state=state
|
||||
state=recurrent_state
|
||||
)
|
||||
|
||||
r = super().sample(
|
||||
|
@ -180,10 +186,17 @@ class AR_NAR(Base):
|
|||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
length_penalty=sampling_length_penalty,
|
||||
beam_width=sampling_beam_width,
|
||||
|
||||
mirostat=mirostat,
|
||||
)
|
||||
|
||||
if mirostat is not None:
|
||||
# r is the state
|
||||
mirostat = r
|
||||
# extract token from state
|
||||
r = [ state["token"] for state in mirostat ]
|
||||
# we do it here because the sampler will already expand our logits list
|
||||
if sampling_beam_width > 0:
|
||||
elif sampling_beam_width > 0:
|
||||
# expand tuple
|
||||
r, s = r
|
||||
# first step, expand batch
|
||||
|
|
|
@ -136,6 +136,55 @@ def top_k_logits_list( logits_list, k ):
|
|||
candidates[i] = tuple(t)
|
||||
return candidates
|
||||
|
||||
|
||||
# Credit to: https://github.com/basusourya/mirostat/
|
||||
# performs mirostat-based sampling
|
||||
# logits: Tensor of logit probabilities
|
||||
# state: the mirostat state
|
||||
def mirostat_sample(logits, state=None):
    """Sample one token with mirostat truncation and update the running state.

    Credit to: https://github.com/basusourya/mirostat/

    The last row of ``logits`` is sorted, a cut-off ``k`` is estimated from the
    shape of the sorted distribution and the current ``max_surprise``, a token
    is drawn from the top ``k`` candidates, and ``max_surprise`` is nudged by
    ``eta`` times the difference between the observed surprise and ``tau``.

    Args:
        logits: 2D tensor of logits; only ``logits[-1, :]`` is sampled from.
        state: mutable dict carrying the mirostat state between calls. Must
            hold ``"n"`` (presumably the vocabulary size -- confirm against
            the caller), ``"tau"`` and ``"eta"``. ``"max_surprise"``,
            ``"error_surprise"`` and ``"running_total_surprise"`` are
            initialised here when absent.

    Returns:
        The same ``state`` dict, updated in place, with the sampled token id
        stored under ``"token"`` as a one-element tensor.

    Raises:
        TypeError: if ``state`` is ``None``.
    """
    if state is None:
        # The original failed here with an opaque "NoneType is not iterable".
        raise TypeError(
            "mirostat_sample() requires a state dict holding 'n', 'tau' and 'eta'"
        )

    def compute_k(prob, n, tau):
        """Estimate how many of the sorted candidates to keep."""
        num = 0
        den = 0
        # Use at most the first 100 adjacent pairs, but never read past the end
        # of a small vocabulary.
        for i in range(min(100, len(prob) - 1)):
            # Candidates masked out by earlier top-k / top-p filtering have zero
            # probability; the ratio is undefined from there on, so stop.
            if prob[i] <= 0 or prob[i + 1] <= 0:
                break
            b = prob[i] / prob[i + 1]
            t = (i + 2) / (i + 1)
            num += math.log(b) * math.log(t)
            den += math.log(t) ** 2

        try:
            s = num / den
            eps = s - 1
            base = (eps * (2 ** tau)) / (1 - n ** (-eps))
            if not base > 0:
                # A non-positive (or NaN) base has no real-valued root.
                return len(prob)
            return round(base ** (1 / s))
        except (ZeroDivisionError, OverflowError, ValueError):
            # No usable estimate (too few candidates, flat distribution, ...):
            # keep every candidate, which degrades to plain sampling.
            return len(prob)

    if "max_surprise" not in state:
        state["max_surprise"] = state["tau"] * 2

    if "error_surprise" not in state:
        state["error_surprise"] = 0

    if "running_total_surprise" not in state:
        state["running_total_surprise"] = 0

    sorted_logits, sorted_indices = torch.sort(logits[-1, :], descending=True)
    prob_original = torch.softmax(sorted_logits, dim=-1).tolist()

    k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1
    # Keep the cut-off inside the candidate list so the slices below are never
    # empty and never wrap around via a negative bound.
    k = max(1, min(k, len(prob_original)))

    sorted_logits = sorted_logits[0:k]
    sorted_indices = sorted_indices[0:k]
    prob_topk = torch.softmax(sorted_logits, dim=0)
    prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)

    # Surprise is measured against the untruncated distribution.
    state["index_surprise"] = math.log2(1 / prob_original[prev_i])
    state["running_total_surprise"] += state["index_surprise"]
    state["error_surprise"] = state["index_surprise"] - state["tau"]
    state["max_surprise"] -= state["eta"] * state["error_surprise"]
    state["token"] = sorted_indices[prev_i]

    return state
|
||||
|
||||
|
||||
# automagically parses a batch-list and returns it as a list
|
||||
class Embedding(nn.Embedding):
|
||||
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
||||
|
@ -455,6 +504,8 @@ class Base(nn.Module):
|
|||
length_penalty: float = 0.0,
|
||||
|
||||
beam_width: int = 0,
|
||||
|
||||
mirostat: list[dict] | None = None,
|
||||
):
|
||||
# (NAR) return the entire generated response
|
||||
if quant_levels is not None:
|
||||
|
@ -480,6 +531,12 @@ class Base(nn.Module):
|
|||
if top_k > 0 or top_p < 1.0:
|
||||
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
|
||||
|
||||
# do mirostat sampling
|
||||
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
|
||||
if mirostat is not None:
|
||||
# mirostat sampling
|
||||
return [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]
|
||||
|
||||
# do beam search (naive implementation)
|
||||
# picks the top-k across all batches, and re-batches those resultant tokens
|
||||
# returns the logit scores as well to be P-concatted with the previous scores
|
||||
|
@ -491,7 +548,6 @@ class Base(nn.Module):
|
|||
return res, scores
|
||||
|
||||
# and sample
|
||||
# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
|
||||
return [ Categorical(logits=logit).sample() for logit in logits ]
|
||||
|
||||
def example_usage():
|
||||
|
|
|
@ -68,8 +68,9 @@ class NAR(Base):
|
|||
sampling_top_p: float = 1.0,
|
||||
sampling_repetition_penalty: float = 1.0,
|
||||
sampling_repetition_penalty_decay: float = 0.0,
|
||||
sampling_length_penalty: float = 0.0,
|
||||
sampling_beam_width: int = 0,
|
||||
sampling_length_penalty: float = 0.0, # unused
|
||||
sampling_beam_width: int = 0, # unused
|
||||
sampling_mirostat_tau: float = 0.0, # unused
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -140,6 +141,8 @@ class NAR(Base):
|
|||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat_tau=sampling_mirostat_tau,
|
||||
#mirostat_state=mirostat_state,
|
||||
)
|
||||
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
|
|
@ -79,6 +79,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
|
||||
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
|
||||
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
|
||||
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
|
||||
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
|
||||
|
@ -101,7 +103,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
top_k=args.top_k,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
repetition_penalty_decay=args.repetition_penalty_decay,
|
||||
length_penalty=args.length_penalty
|
||||
length_penalty=args.length_penalty,
|
||||
mirostat_tau=args.mirostat_tau,
|
||||
mirostat_eta=args.mirostat_eta,
|
||||
)
|
||||
|
||||
wav = wav.squeeze(0).cpu().numpy()
|
||||
|
@ -183,20 +187,23 @@ with ui:
|
|||
with gr.Column(scale=7):
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
||||
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=3, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
||||
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
|
||||
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
||||
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
||||
layout["inference"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=5.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
|
||||
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
|
||||
|
||||
layout["inference"]["buttons"]["inference"].click(
|
||||
fn=do_inference,
|
||||
|
|
Loading…
Reference in New Issue
Block a user