possible speedup for samplers that require a list of previous tokens (the DRY sampler made me realize that I should copy the tolist() thing from the rep pen sampler for everything else)
commit ebf848d249 · parent 55b0121b1a
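The gist of the speedup: convert the token-history tensor to a plain Python list once at the call site, instead of poking at the tensor element-by-element inside every sampler. A minimal sketch of the cost difference, not taken from this commit (sizes and names are made up):

```python
import timeit
import torch

previous = torch.randint(0, 1024, (2048,))  # hypothetical token history

def via_tensor_indexing():
    # every previous[j] crosses the Python/C++ boundary and allocates a 0-dim tensor
    return [int(previous[j]) for j in range(previous.shape[0])]

def via_tolist():
    # one bulk conversion up front, then cheap native-list access
    tokens = previous.tolist()
    return [tokens[j] for j in range(len(tokens))]

print("per-element tensor indexing:", timeit.timeit(via_tensor_indexing, number=100))
print("single .tolist() conversion:", timeit.timeit(via_tolist, number=100))
```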
@@ -223,6 +223,7 @@ So far, this only allows you to load a different model without needing to restar
 ## To-Do
 
 * [x] train and release a serviceable model for finetuning against.
+  - LoRA tests show it's already very capable, although there's room for higher quality (possibly in better NAR training).
 * [ ] train and release a ***good*** zero-shot model.
   - this should, hopefully, just require another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
 * [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
@@ -234,6 +235,8 @@ So far, this only allows you to load a different model without needing to restar
 * [ ] clean up the README, and document, document, document onto the wiki.
 * [ ] extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
   - this requires a good foundational model before extending it to transfer tasks onto.
+* [ ] extend using [VALL-E 2](https://arxiv.org/pdf/2406.05370)'s features (grouped code modeling + repetition aware sampling)
+  - desu these don't seem to be worthwhile improvements, as inferencing is already rather fast, and RAS is just a fancy sampler.
 * [ ] audio streaming
   - this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio.
   - something similar to HiFiGAN (or the one for TorToiSe) trained on the last hidden states of the AR *might* also enable an alternate way for streaming.
@@ -1354,7 +1354,7 @@ class Base(nn.Module):
 
         # perform repetition penalizing
         if "len" not in self.capabilities:
-            logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+            logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
 
         # argmax instead
         if temperature <= 0.0:
@@ -1380,7 +1380,7 @@ class Base(nn.Module):
 
         # do DRY sampling
         if dry_multiplier > 0.0:
-            logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
+            logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
 
         # do mirostat sampling
         # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
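Both call sites above now share the same pattern: materialize `resps[:, -1]` as a list once, then hand that list to every sampler that walks the history. A sketch under assumed shapes (the `(timesteps, levels)` layout is a guess, not pulled from the repo):

```python
import torch

resps = torch.randint(0, 1024, (75, 8))  # hypothetical (timesteps, quantizer levels) codes
previous = resps[:, -1].tolist()         # one bulk copy; plain ints from here on

assert all(isinstance(token, int) for token in previous)
# hand the same list to each history-aware sampler:
#   reptition_penalize(logit, previous=previous, ...)
#   dry_sampling(logit, previous=previous, ...)
```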
@@ -13,7 +13,7 @@ def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True )
         return logits
 
     unique = set()
-    priors = reversed(previous.tolist())
+    priors = reversed(previous)
     for distance, token in enumerate(priors):
         # skip if we're only applying the decay once
         if one_time and token in unique:
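Why the function body can drop its own `.tolist()`: once the caller passes a plain list, `reversed()` yields native ints rather than 0-dim tensors, which are slower to compare and which hash by object identity (so tensor tokens would also defeat the `unique` set). A quick demonstration, not repo code:

```python
import torch

history = torch.tensor([5, 3, 5, 9])

print([type(t).__name__ for t in reversed(history)])
# ['Tensor', ...] — each step yields a 0-dim tensor, costly to hash/compare

print([type(t).__name__ for t in reversed(history.tolist())])
# ['int', ...] — cheap native ints for the `unique` membership check
```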
@@ -181,25 +181,20 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2
     lengths = {}
     for i, token in enumerate( previous ):
         length = 1
-        while True:
+        while length < max(allowed_length, 50):
             j = i - length
 
             # Start of input reached.
             if j < 0:
                 break
 
-            previous_token = previous[-length-1].item()
-
             # Start of match reached.
-            if previous[j] != previous_token:
+            if previous[j] != previous[-length-1]:
                 break
 
             length += 1
 
-        if token in lengths:
-            lengths[token] = max(length, lengths[token])
-        else:
-            lengths[token] = length
+        lengths[token] = max(length, lengths[token]) if token in lengths else length
 
     for token, length in lengths.items():
         if length < allowed_length:
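Put together, the reworked loop amounts to the standalone sketch below. The final penalty line is the usual DRY formulation (`factor * base ** (length - allowed_length)`) and is an assumption here, since the hunk cuts off right after the `allowed_length` check:

```python
import torch

def dry_sketch(logits, previous, factor=0.8, base=1.75, allowed_length=2):
    # longest repeat length ending at each token, mirroring the loop in the diff
    lengths = {}
    for i, token in enumerate(previous):
        length = 1
        # bounded backward scan, as in the new `while length < max(allowed_length, 50)`
        while length < max(allowed_length, 50):
            j = i - length
            if j < 0:  # start of input reached
                break
            if previous[j] != previous[-length-1]:  # start of match reached
                break
            length += 1
        lengths[token] = max(length, lengths.get(token, 0))

    for token, length in lengths.items():
        if length < allowed_length:
            continue
        # assumed penalty: exponential in how far the repeat exceeds the allowance
        logits[..., token] -= factor * base ** (length - allowed_length)
    return logits

# e.g. dry_sketch(torch.zeros(1024), previous=[5, 3, 5, 3, 5])
```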