possible speedup for samplers that require a list of previous tokens (the DRY sampler made me realize that I should copy the tolist() thing from the rep pen sampler for everything else)
This commit is contained in:
parent 55b0121b1a
commit ebf848d249
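The core of the change — hoisting the tensor-to-list conversion up to the call site so the samplers iterate a plain Python list instead of repeatedly indexing into a tensor — can be sketched with NumPy standing in for torch. This is a minimal illustration only; `penalize` is a hypothetical stand-in for the repo's samplers, not a function from the codebase:

```python
import numpy as np

# hypothetical stand-in for a sampler that walks the previous tokens
def penalize(logits, previous, factor=1.3):
    for token in previous:
        # iterating a plain list avoids per-element tensor indexing / .item() calls
        logits[token] /= factor
    return logits

resps = np.array([[1, 2], [3, 4], [5, 6]])  # stand-in for a response tensor
previous = resps[:, -1].tolist()            # convert once at the call site...
logits = penalize(np.ones(8), previous)     # ...and every sampler reuses the list
```

The one-time `.tolist()` is cheap relative to extracting scalars one at a time inside every sampler, which is the speedup the commit message describes.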
@@ -223,6 +223,7 @@ So far, this only allows you to load a different model without needing to restar
## To-Do
* [x] train and release a serviceable model for finetuning against.
  - LoRA tests show it's already very capable, although there's room for higher quality (possibly in better NAR training).
* [ ] train and release a ***good*** zero-shot model.
  - this should, hopefully, just require another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
@@ -234,6 +235,8 @@ So far, this only allows you to load a different model without needing to restar
* [ ] clean up the README, and document, document, document onto the wiki.
* [ ] extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
  - this requires a good foundational model before extending it to transfer tasks onto.
* [ ] extend using [VALL-E 2](https://arxiv.org/pdf/2406.05370)'s features (grouped code modeling + repetition aware sampling)
  - desu these don't seem to be worthwhile improvements, as inferencing is already rather fast, and RAS is just a fancy sampler.
* [ ] audio streaming
  - this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio.
  - something similar to HiFiGAN (or the one for TorToiSe) trained on the last hidden states of the AR *might* also enable an alternate way for streaming.
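The "sampling-then-decoding-to-audio" idea in the streaming bullet can be sketched as a generator that decodes fixed-size chunks of tokens as they are sampled, rather than waiting for the full sequence. This is a hypothetical sketch of the trick, not code from this repo; `sample_step` and `decode_to_audio` are placeholder callables:

```python
def stream_tts(sample_step, decode_to_audio, chunk=25, max_steps=500):
    """Yield audio for every `chunk` sampled tokens instead of the full sequence."""
    tokens = []
    for _ in range(max_steps):
        token = sample_step(tokens)
        if token is None:        # placeholder stop condition
            break
        tokens.append(token)
        if len(tokens) % chunk == 0:
            yield decode_to_audio(tokens[-chunk:])
    if len(tokens) % chunk:      # flush the trailing partial chunk
        yield decode_to_audio(tokens[-(len(tokens) % chunk):])
```

Chunked decoding trades a little latency per chunk for audio that starts playing long before sampling finishes, which is why no architecture change is strictly required.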
@@ -1354,7 +1354,7 @@ class Base(nn.Module):

		# perform repetition penalizing
		if "len" not in self.capabilities:
-			logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]

		# argmax instead
		if temperature <= 0.0:
@@ -1380,7 +1380,7 @@ class Base(nn.Module):

		# do DRY sampling
		if dry_multiplier > 0.0:
-			logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
+			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]

		# do mirostat sampling
		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
@@ -13,7 +13,7 @@ def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True )
		return logits

	unique = set()
-	priors = reversed(previous.tolist())
+	priors = reversed(previous)
	for distance, token in enumerate(priors):
		# skip if we're only applying the decay once
		if one_time and token in unique:
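For reference, the decayed repetition penalty that `reptition_penalize` applies can be sketched in pure Python over the plain token list it now receives. The decay schedule below is an assumption for illustration (full `factor` on the most recent occurrence, fading geometrically with distance), not the repo's exact formula:

```python
def repetition_penalize_sketch(logits, previous, factor=1.3, decay=0.9, one_time=True):
    # `previous` is now a plain Python list, newest token last
    unique = set()
    for distance, token in enumerate(reversed(previous)):
        # skip if we're only applying the decay once per unique token
        if one_time and token in unique:
            continue
        unique.add(token)
        # assumed schedule: penalty of `factor` at distance 0,
        # decaying toward no penalty for distant tokens
        logits[token] /= 1.0 + (factor - 1.0) * (decay ** distance)
    return logits
```

Because the loop only reads list elements and does set membership tests, converting the tensor to a list once at the call site (as this commit does) keeps the hot loop free of tensor ops.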
@@ -181,25 +181,20 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2
	lengths = {}
	for i, token in enumerate( previous ):
		length = 1
-		while True:
+		while length < max(allowed_length, 50):
			j = i - length

			# Start of input reached.
			if j < 0:
				break

-			previous_token = previous[-length-1].item()
-
			# Start of match reached.
-			if previous[j] != previous_token:
+			if previous[j] != previous[-length-1]:
				break

			length += 1

-		if token in lengths:
-			lengths[token] = max(length, lengths[token])
-		else:
-			lengths[token] = length
+		lengths[token] = max(length, lengths[token]) if token in lengths else length

	for token, length in lengths.items():
		if length < allowed_length:
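The match-length scan is the heart of DRY: for each token, it measures how long the run ending at that position matches the tail of the sequence. Extracted into a self-contained helper, mirroring the post-commit logic over a plain list (the penalty application itself is omitted here):

```python
def dry_match_lengths(previous, allowed_length=2):
    """For each token, the length of the run ending at it that matches the tail."""
    lengths = {}
    for i, token in enumerate(previous):
        length = 1
        while length < max(allowed_length, 50):   # cap the scan, as in the commit
            j = i - length
            if j < 0:
                break   # start of input reached
            if previous[j] != previous[-length - 1]:
                break   # start of match reached
            length += 1
        lengths[token] = max(length, lengths[token]) if token in lengths else length
    return lengths
```

Replacing `while True` with the capped condition bounds the per-token scan, and comparing `previous[j] != previous[-length-1]` on a plain list drops the per-iteration `.item()` extraction the old tensor-based version needed.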