From ebf848d24925b64742eab451de91fa5e31d0ade6 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 29 Jul 2024 20:23:26 -0500
Subject: [PATCH] possible speedup for samplers that require a list of
 previous tokens (the DRY sampler made me realize that I should copy the
 tolist() thing from the rep pen sampler for everything else)

---
 README.md             |  3 +++
 vall_e/models/base.py |  4 ++--
 vall_e/samplers.py    | 13 ++++---------
 3 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index d845ad7..e244769 100755
--- a/README.md
+++ b/README.md
@@ -223,6 +223,7 @@ So far, this only allows you to load a different model without needing to restar
 ## To-Do
 
 * [x] train and release a serviceable model for finetuning against.
+  - LoRA tests show it's already very capable, although there's room for higher quality (possibly in better NAR training).
 * [ ] train and release a ***good*** zero-shot model.
   - this should, hopefully, just simply requires another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
 * [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
@@ -234,6 +235,8 @@ So far, this only allows you to load a different model without needing to restar
 * [ ] clean up the README, and document, document, document onto the wiki.
 * [ ] extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
   - this requires a good foundational model before extending it to transfer tasks onto.
+* [ ] extend using [VALL-E 2](https://arxiv.org/pdf/2406.05370)'s features (grouped code modeling + repetition aware sampling)
+  - desu these don't seem to be worthwhile improvements, as inferencing is already rather fast, and RAS is just a fancy sampler.
 * [ ] audio streaming
   - this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio.
   - something similar to HiFiGAN (or the one for TorToiSe) trained on the last hidden states of the AR *might* also enable an alternate way for streaming.
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 6dd223c..68f1fb6 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1354,7 +1354,7 @@ class Base(nn.Module):
 
 		# perform repetition penalizing
 		if "len" not in self.capabilities:
-			logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
 
 		# argmax instead
 		if temperature <= 0.0:
@@ -1380,7 +1380,7 @@ class Base(nn.Module):
 
 		# do DRY sampling
 		if dry_multiplier > 0.0:
-			logits = [ dry_sampling(logit, previous=resps[:, -1], factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
+			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
 
 		# do mirostat sampling
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index dcbf856..74ef9f0 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -13,7 +13,7 @@ def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True )
 		return logits
 
 	unique = set()
-	priors = reversed(previous.tolist())
+	priors = reversed(previous)
 	for distance, token in enumerate(priors):
 		# skip if we're only applying the decay once
 		if one_time and token in unique:
@@ -181,25 +181,20 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2
 	lengths = {}
 	for i, token in enumerate( previous ):
 		length = 1
-		while True:
+		while length < max(allowed_length, 50):
 			j = i - length
 
 			# Start of input reached.
 			if j < 0:
 				break
 
-			previous_token = previous[-length-1].item()
 
-			# Start of match reached.
-			if previous[j] != previous_token:
+			if previous[j] != previous[-length-1]:
 				break
 
 			length += 1
 
-		if token in lengths:
-			lengths[token] = max(length, lengths[token])
-		else:
-			lengths[token] = length
+		lengths[token] = max(length, lengths[token]) if token in lengths else length
 
 	for token, length in lengths.items():
 		if length < allowed_length:
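Reviewer note, not part of the patch: the speedup comes from converting the token history to a plain Python list once with `.tolist()`, instead of letting each sampler index a tensor element-by-element (every `previous[i]` yields a 0-dim tensor, and the per-element dispatch and `.item()` calls dominate the loop). Below is a minimal standalone sketch of that difference; the history size and repeat count are arbitrary and no vall_e code is involved.

```python
# Standalone sketch (not part of the patch): looping over a CPU tensor of token
# ids versus converting it to a Python list once with tolist().
import timeit

import torch

previous = torch.randint(0, 1024, (4096,))  # stand-in for a token history

def loop_tensor():
    # Each previous[i] builds a 0-dim tensor; int() then copies it back out,
    # so every iteration pays tensor-dispatch overhead.
    total = 0
    for i in range(previous.shape[0]):
        total += int(previous[i])
    return total

def loop_list():
    # One bulk tolist() call up front, then everything is a native Python int.
    total = 0
    for t in previous.tolist():
        total += t
    return total

print("tensor indexing:", timeit.timeit(loop_tensor, number=20))
print("tolist() once:  ", timeit.timeit(loop_list, number=20))
```

The exact numbers depend on the machine, but the list loop avoids per-element tensor dispatch entirely, which is what the `resps[:, -1].tolist()` calls in `base.py` buy for both the repetition penalty and DRY paths.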
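The `samplers.py` hunk additionally lets `dry_sampling` compare plain ints directly (no `.item()`) and caps the backward scan at `max(allowed_length, 50)`, bounding the per-token match search on long histories. A standalone sketch of just that match-length loop, mirroring the patched logic on a plain list; the helper name and example sequence are made up for illustration.

```python
# Standalone sketch (not part of the patch): the backward match-length scan that
# dry_sampling performs once `previous` is a plain list of ints.
def match_lengths(previous: list[int], allowed_length: int = 2) -> dict[int, int]:
    lengths: dict[int, int] = {}
    for i, token in enumerate(previous):
        length = 1
        # Walk backwards, comparing the window ending at i against the tail of
        # the sequence; the hardcoded 50 bounds worst-case cost per token.
        while length < max(allowed_length, 50):
            j = i - length
            if j < 0:  # start of input reached
                break
            if previous[j] != previous[-length - 1]:  # match broken
                break
            length += 1
        lengths[token] = max(length, lengths[token]) if token in lengths else length
    return lengths

print(match_lengths([7, 8, 9, 7, 8, 9]))  # trailing 7, 8, 9 repeats an earlier run
```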