added mirostat sampling (given a partially trained model, it produced far more decent output than I expected; I need to test it on a better-trained model)

This commit is contained in:
mrq 2023-09-18 18:55:41 -05:00
parent 2567e082b5
commit a6bfe43590
9 changed files with 149 additions and 18 deletions

View File

@ -140,13 +140,19 @@ And some experimental sampling flags you can use too (your mileage will ***defin
* `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
* `--repetition-penalty-decay`: decays the above penalty based on how far back in the sequence the token appeared.
* `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finicky to use, as the AR is already well correlated with the output length.
- * `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling. This is a very naive implementation that's effectively just greedy sampling across `B` spaces.
+ * `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling.
+ This is a very naive implementation that's effectively just greedy sampling across `B` spaces.
+ * `--mirostat-tau`: (AR only) the "surprise value" when performing mirostat sampling (see the sketch after this list).
+ This simply uplifts the [original implementation](https://github.com/basusourya/mirostat/blob/master/mirostat.py) to perform it.
+ **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime, at least).
+ * `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling, applied to the maximum surprise.
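For intuition: mirostat tracks a running "maximum surprise" and nudges it toward the target `tau` after every sampled token. A minimal sketch of that feedback update, mirroring the state fields used by `mirostat_sample` further down in this commit:

```python
import math

# Minimal sketch of mirostat's feedback loop: tau is the target surprise,
# eta is the learning rate applied to the running maximum surprise.
def mirostat_update(max_surprise: float, token_prob: float, tau: float, eta: float) -> float:
    index_surprise = math.log2(1.0 / token_prob)  # surprise of the token just sampled
    error_surprise = index_surprise - tau         # deviation from the target surprise
    return max_surprise - eta * error_surprise    # nudge the truncation budget toward tau
```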
## To-Do
* reduce load time for creating / preparing dataloaders (hint: remove use of `Path.glob` and `Path.rglob`).
* train and release a ***good*** model.
* extend to multiple languages (VALL-E X) and ~~extend to~~ train SpeechX features.
+ This can easily be done by adding additional embeddings + tokens, rather than cramming them into the input prompt embedding (a rough sketch follows this list).
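As a rough sketch of that idea (hypothetical names, not the repo's actual API), a per-task embedding can be summed into the input sequence instead of packing task markers into the prompt:

```python
import torch
import torch.nn as nn

# Hypothetical sketch: one learned vector per SpeechX task, summed into the
# input sequence, rather than cramming task markers into the prompt itself.
class TaskEmbedding(nn.Module):
    def __init__(self, n_tasks: int, d_model: int):
        super().__init__()
        self.emb = nn.Embedding(n_tasks, d_model)

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        # x: (batch, seq, d_model); broadcast the task vector over the whole sequence
        return x + self.emb.weight[task_id]
```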
## Notice
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.

View File

@ -28,6 +28,9 @@ def main():
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--beam-width", type=int, default=0)
parser.add_argument("--mirostat-tau", type=float, default=0)
parser.add_argument("--mirostat-eta", type=float, default=0)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
@ -35,7 +38,19 @@ def main():
	args = parser.parse_args()

	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
-	tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, beam_width=args.beam_width )
+	tts.inference(
+		text=args.text,
+		references=args.references,
+		out_path=args.out_path,
+		input_prompt_length=args.input_prompt_length,
+		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
+		ar_temp=args.ar_temp, nar_temp=args.nar_temp,
+		top_p=args.top_p, top_k=args.top_k,
+		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+		length_penalty=args.length_penalty,
+		beam_width=args.beam_width,
+		mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta
+	)

if __name__ == "__main__":
	main()

View File

@ -291,7 +291,8 @@ class Dataset(_Dataset):
		# shuffle it up a bit
		prom_length = 0
-		trim_length = int(cfg.dataset.prompt_duration * 75)
+		#trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+		trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)

		for _ in range(cfg.dataset.max_prompts):
			path = random.choice(choices)

View File

@ -154,6 +154,8 @@ class TTS():
		repetition_penalty_decay=0.0,
		length_penalty=0.0,
		beam_width=0,
+		mirostat_tau=0,
+		mirostat_eta=0.1,
		out_path=None
	):
		if out_path is None:
@ -166,9 +168,24 @@ class TTS():
		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

		with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
-			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width)
+			resps_list = self.ar(
+				text_list=[phns], proms_list=[prom], max_steps=max_ar_steps,
+				sampling_temperature=ar_temp,
+				sampling_top_p=top_p, sampling_top_k=top_k,
+				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+				sampling_length_penalty=length_penalty,
+				sampling_beam_width=beam_width,
+				sampling_mirostat_tau=mirostat_tau,
+				sampling_mirostat_eta=mirostat_eta,
+			)
			resps_list = [r.unsqueeze(-1) for r in resps_list]
-			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width)
+			resps_list = self.nar(
+				text_list=[phns], proms_list=[prom], resps_list=resps_list,
+				max_levels=max_nar_levels,
+				sampling_temperature=nar_temp,
+				sampling_top_p=top_p, sampling_top_k=top_k,
+				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+			)

		wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
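For reference, a hypothetical end-to-end call of this API with mirostat enabled (paths and config values are placeholders, not from the repo; all other sampling arguments keep their defaults):

```python
# Hypothetical usage sketch, not part of the commit; paths are placeholders.
tts = TTS( config="./config.yaml", device="cuda", amp=True )
tts.inference(
    text="hello world",
    references=["./reference.wav"],
    out_path="./output.wav",
    mirostat_tau=3.0, mirostat_eta=0.1,  # tau > 0 switches the AR over to mirostat sampling
)
```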

View File

@ -99,6 +99,9 @@ class AR(Base):
		sampling_repetition_penalty_decay: float = 0.0,
		sampling_length_penalty: float = 0.0,
		sampling_beam_width: int = 0,
+		sampling_mirostat_tau: float = 0.0,
+		sampling_mirostat_eta: float = 0.1,
	):
		if resps_list is not None:
			if self.interleave:
@ -120,7 +123,10 @@ class AR(Base):
		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
		stopped = torch.zeros(batch_size, device=device).bool()

-		state = {} if cfg.inference.recurrent_forward else None
+		recurrent_state = {} if cfg.inference.recurrent_forward else None
+		# one independent mirostat state dict per batch entry; the reference
+		# implementation initializes the maximum surprise to 2 * tau
+		mirostat = [
+			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_tau * 2, "error_surprise": 0, "running_total_surprise": 0}
+			for _ in range(batch_size)
+		] if sampling_mirostat_tau > 0.0 else None

		sampling_beam_width_use_logs = True
		scores = [ 1.0 ] * sampling_beam_width
@ -136,7 +142,7 @@ class AR(Base):
				proms_list=proms_list,
				resps_list=resps_list,
-				state=state
+				state=recurrent_state
			)

			r = super().sample(
@ -150,10 +156,17 @@ class AR(Base):
				repetition_penalty_decay=sampling_repetition_penalty_decay,
				length_penalty=sampling_length_penalty,
				beam_width=sampling_beam_width,
+				mirostat=mirostat,
			)

+			if mirostat is not None:
+				# r is the updated mirostat state, not the sampled tokens themselves
+				mirostat = r
+				# extract the sampled token from each batch entry's state
+				r = [ state["token"] for state in mirostat ]
			# we do it here because the sampler will already expand our logits list
-			if sampling_beam_width > 0:
+			elif sampling_beam_width > 0:
				# expand tuple
				r, s = r
				# first step, expand batch
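Note the contract this introduces: when mirostat is active, the sampler returns the updated per-batch state dicts instead of raw tokens, and when beam search is active it returns a (tokens, scores) tuple. An illustrative helper (not code from the repo) condensing how the loop above normalizes the three shapes back to tokens:

```python
# Illustrative only: the three return shapes of Base.sample(), and how the AR
# loop recovers plain tokens from each. `r` is the sampler's return value.
def unpack_sample(r, mirostat=None, beam_width=0):
    if mirostat is not None:
        mirostat = r                                 # r is the updated per-batch state
        return [ state["token"] for state in mirostat ], mirostat, None
    if beam_width > 0:
        tokens, scores = r                           # beam search also returns logit scores
        return tokens, None, scores
    return r, None, None                             # plain sampling already yields tokens
```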

View File

@ -85,6 +85,8 @@ class AR_NAR(Base):
		sampling_repetition_penalty_decay: float = 0.0,
		sampling_length_penalty: float = 0.0,
		sampling_beam_width: int = 0,
+		sampling_mirostat_tau: float = 0.0,
+		sampling_mirostat_eta: float = 0.1,
	):
		device = text_list[0].device
		batch_size = len(text_list)
@ -140,6 +142,7 @@ class AR_NAR(Base):
				repetition_penalty_decay=sampling_repetition_penalty_decay,
				#length_penalty=sampling_length_penalty,
				#beam_width=sampling_beam_width,
+				#mirostat=mirostat,
			)

			prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
@ -150,7 +153,10 @@ class AR_NAR(Base):
		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
		stopped = torch.zeros(batch_size, device=device).bool()

-		state = {} if cfg.inference.recurrent_forward else None
+		recurrent_state = {} if cfg.inference.recurrent_forward else None
+		# one independent mirostat state dict per batch entry; the reference
+		# implementation initializes the maximum surprise to 2 * tau
+		mirostat = [
+			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_tau * 2, "error_surprise": 0, "running_total_surprise": 0}
+			for _ in range(batch_size)
+		] if sampling_mirostat_tau > 0.0 else None

		sampling_beam_width_use_logs = True
		scores = [ 1.0 ] * sampling_beam_width
@ -166,7 +172,7 @@ class AR_NAR(Base):
				proms_list=proms_list,
				resps_list=resps_list,
-				state=state
+				state=recurrent_state
			)

			r = super().sample(
@ -180,10 +186,17 @@ class AR_NAR(Base):
				repetition_penalty_decay=sampling_repetition_penalty_decay,
				length_penalty=sampling_length_penalty,
				beam_width=sampling_beam_width,
+				mirostat=mirostat,
			)

+			if mirostat is not None:
+				# r is the updated mirostat state, not the sampled tokens themselves
+				mirostat = r
+				# extract the sampled token from each batch entry's state
+				r = [ state["token"] for state in mirostat ]
			# we do it here because the sampler will already expand our logits list
-			if sampling_beam_width > 0:
+			elif sampling_beam_width > 0:
				# expand tuple
				r, s = r
				# first step, expand batch

View File

@ -136,6 +136,55 @@ def top_k_logits_list( logits_list, k ):
			candidates[i] = tuple(t)

	return candidates

+# Credit to: https://github.com/basusourya/mirostat/
+# performs mirostat-based sampling
+#  logits: tensor of logits; only the final step (last row) is sampled from
+#  state: the mirostat state dict for this batch entry
+def mirostat_sample( logits, state = None ):
+	# estimate the Zipf exponent s from the top of the sorted distribution,
+	# then derive the truncation size k from the target (maximum) surprise
+	def compute_k(prob, n, tau):
+		num = 0
+		den = 0
+		for i in range(100):
+			b = prob[i]/prob[i+1]
+			t = (i+2)/(i+1)
+			num += math.log(b)*math.log(t)
+			den += math.log(t)**2
+		s = num/den
+		eps = s-1
+		k = ((eps*(2**(tau)))/(1-n**(-eps)))**(1/s)
+		k = round(k)
+		return k
+
+	if "max_surprise" not in state:
+		state["max_surprise"] = state["tau"] * 2
+
+	if "error_surprise" not in state:
+		state["error_surprise"] = 0
+
+	if "running_total_surprise" not in state:
+		state["running_total_surprise"] = 0
+
+	sorted_logits, sorted_indices = torch.sort( logits[-1, :], descending=True )
+	prob_original = torch.softmax( sorted_logits, dim=-1 ).tolist()
+
+	# truncate to the k most probable tokens and sample among them
+	k = compute_k(prob_original, state["n"], state["max_surprise"]) + 1
+
+	sorted_logits = sorted_logits[0:k]
+	sorted_indices = sorted_indices[0:k]
+	prob_topk = torch.softmax(sorted_logits, dim = 0)
+	prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
+
+	# update the running surprise statistics and stash the sampled token
+	state["index_surprise"] = math.log2(1/prob_original[prev_i])
+	state["running_total_surprise"] += state["index_surprise"]
+	state["error_surprise"] = state["index_surprise"] - state["tau"]
+	state["max_surprise"] -= state["eta"] * state["error_surprise"]
+	state["token"] = sorted_indices[prev_i]
+
+	return state
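A quick smoke test of the function above (hypothetical, not part of the commit; it assumes `math` and `torch` are imported at module level, as the code requires): run a few steps over random logits and watch `max_surprise` adapt toward `tau`.

```python
import torch

# Five mirostat steps over random logits; max_surprise should drift toward tau (3.0 here).
state = {"n": 1024, "tau": 3.0, "eta": 0.1, "max_surprise": 6.0,
         "error_surprise": 0, "running_total_surprise": 0}
logits = torch.randn(1, 1024)  # (steps, vocab); only the last row is sampled from
for _ in range(5):
    state = mirostat_sample(logits, state=state)
    print(int(state["token"]), round(state["max_surprise"], 3))
```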
# automagically parses a batch-list and returns it as a list
class Embedding(nn.Embedding):
	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
@ -455,6 +504,8 @@ class Base(nn.Module):
		length_penalty: float = 0.0,
		beam_width: int = 0,
+		mirostat: list[dict] | None = None,
	):
		# (NAR) return the entire generated response
		if quant_levels is not None:
@ -480,6 +531,12 @@ class Base(nn.Module):
		if top_k > 0 or top_p < 1.0:
			logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]

+		# do mirostat sampling
+		# currently incompatible with beam searching with the way the two are implemented; perhaps a night of brain bashing can make the two work
+		if mirostat is not None:
+			# mirostat sampling returns the updated state dicts; the token lives under state["token"]
+			return [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]

		# do beam search (naive implementation)
		# picks the top-k across all batches, and re-batches those resultant tokens
		# returns the logit scores as well, to be concatted with the previous scores
@ -491,7 +548,6 @@ class Base(nn.Module):
			return res, scores

		# and sample
		# the original implementation used this instead of argmax; it's probably placebo, but it seems to perform better than argmax
		return [ Categorical(logits=logit).sample() for logit in logits ]
def example_usage():

View File

@ -68,8 +68,9 @@ class NAR(Base):
		sampling_top_p: float = 1.0,
		sampling_repetition_penalty: float = 1.0,
		sampling_repetition_penalty_decay: float = 0.0,
-		sampling_length_penalty: float = 0.0,
-		sampling_beam_width: int = 0,
+		sampling_length_penalty: float = 0.0, # unused
+		sampling_beam_width: int = 0, # unused
+		sampling_mirostat_tau: float = 0.0, # unused
	):
		"""
		Args:
@ -140,6 +141,8 @@ class NAR(Base):
				repetition_penalty_decay=sampling_repetition_penalty_decay,
				#length_penalty=sampling_length_penalty,
				#beam_width=sampling_beam_width,
+				#mirostat_tau=sampling_mirostat_tau,
+				#mirostat_state=mirostat_state,
			)

			prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) for rs, r in zip(prev_list, resps_list) ]

View File

@ -79,6 +79,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@ -101,7 +103,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
			top_k=args.top_k,
			repetition_penalty=args.repetition_penalty,
			repetition_penalty_decay=args.repetition_penalty_decay,
-			length_penalty=args.length_penalty
+			length_penalty=args.length_penalty,
+			mirostat_tau=args.mirostat_tau,
+			mirostat_eta=args.mirostat_eta,
		)

	wav = wav.squeeze(0).cpu().numpy()
@ -183,20 +187,23 @@ with ui:
		with gr.Column(scale=7):
			with gr.Row():
				layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
-				layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+				layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=3, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
				layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
			with gr.Row():
				layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness of the samples in the AR.")
				layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)", info="Modifies the randomness of the samples in the NAR.")
			with gr.Row():
-				layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
+				layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
				layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
				layout["inference"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
			with gr.Row():
				layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty on tokens based on how often they appear in the sequence.")
				layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the repetition penalty based on how far back in the sequence the token appeared.")
				layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
+			with gr.Row():
+				layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=5.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
+				layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling, applied to the maximum surprise.")

	layout["inference"]["buttons"]["inference"].click(
		fn=do_inference,