possible speedup for samplers that require a list of previous tokens (the DRY sampler made me realize that I should copy the tolist() thing from the rep pen sampler for everything else)

This commit is contained in:
mrq 2024-07-29 20:23:26 -05:00
parent 55b0121b1a
commit ebf848d249
3 changed files with 9 additions and 11 deletions

View File

@ -223,6 +223,7 @@ So far, this only allows you to load a different model without needing to restar
## To-Do
* [x] train and release a serviceable model for finetuning against.
- LoRA tests shows it's already very capable, although there's room for higher quality (possibly in better NAR training).
* [ ] train and release a ***good*** zero-shot model.
- this should, hopefully, just simply requires another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
@ -234,6 +235,8 @@ So far, this only allows you to load a different model without needing to restar
* [ ] clean up the README, and document, document, document onto the wiki.
* [ ] extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- this requires a good foundational model before extending it to transfer tasks onto.
* [ ] extend using [VALL-E 2](https://arxiv.org/pdf/2406.05370)'s features (grouped code modeling + repetition aware sampling)
- desu these don't seem to be worthwhile improvements, as inferencing is already rather fast, and RAS is just a fancy sampler.
* [ ] audio streaming
- this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio.
- something similar to HiFiGAN (or the one for TorToiSe) trained on the last hidden states of the AR *might* also enable an alternate way for streaming.

View File

@ -1354,7 +1354,7 @@ class Base(nn.Module):
# perform repetition penalizing
if "len" not in self.capabilities:
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
# argmax instead
if temperature <= 0.0:
@ -1380,7 +1380,7 @@ class Base(nn.Module):
# do DRY sampling
if dry_multiplier > 0.0:
logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
# do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work

View File

@ -13,7 +13,7 @@ def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True )
return logits
unique = set()
priors = reversed(previous.tolist())
priors = reversed(previous)
for distance, token in enumerate(priors):
# skip if we're only applying the decay once
if one_time and token in unique:
@ -181,25 +181,20 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2
lengths = {}
for i, token in enumerate( previous ):
length = 1
while True:
while length < max(allowed_length, 50):
j = i - length
# Start of input reached.
if j < 0:
break
previous_token = previous[-length-1].item()
# Start of match reached.
if previous[j] != previous_token:
if previous[j] != previous[-length-1]:
break
length += 1
if token in lengths:
lengths[token] = max(length, lengths[token])
else:
lengths[token] = length
lengths[token] = max(length, lengths[token]) if token in lengths else length
for token, length in lengths.items():
if length < allowed_length: