nothing could go wrong part 2 (reverted and rewrote commits since there was a nasty regression)

mrq 2025-03-25 23:06:16 -05:00
parent aa8b32d97e
commit 8641c87611
14 changed files with 41 additions and 15 deletions

@@ -145,6 +145,9 @@ These settings should be avoided:

 To be evaluated thoroughly.
 * The smaller model seems to have hit its capacity limit, while the larger model is slowly improving (although objective metrics are not noted).
+* The model seems pretty quick, even for the large model.
+* The smaller model seems small enough for CPU-only inferencing.
+* Despite its poor zero-shot performance, it could be perfectly fine for finetuning.

 At a glance, compared to the prior model setup, this implementation allows the model to better represent speech, as it's able to see the entire signal and account for it in its latent space, rather than only specific levels of it.
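
To make the "sees the entire signal" point concrete, here is a minimal hypothetical sketch (names, sizes, and shapes are illustrative, not from this repo) contrasting conditioning on a single RVQ level against summing every level's embedding so the latent reflects the full signal:

import torch
import torch.nn as nn

n_levels, codebook_size, d_model = 8, 1024, 512
# one embedding table per RVQ level (illustrative sizes)
embs = nn.ModuleList([nn.Embedding(codebook_size, d_model) for _ in range(n_levels)])
codes = torch.randint(0, codebook_size, (1, 75, n_levels))  # (batch, frames, levels)

# per-level setup: the model only sees one quant level per pass
level = 2
x_single = embs[level](codes[..., level])                      # (1, 75, 512)

# full-signal setup: sum all levels, so the latent accounts for the whole signal
x_full = sum(embs[l](codes[..., l]) for l in range(n_levels))  # (1, 75, 512)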

Mode-only changes (Executable file → Normal file, 0 lines changed):
* vall_e/__init__.py
* vall_e/__main__.py
* vall_e/config.py
* vall_e/data.py
* vall_e/export.py
* vall_e/inference.py
* vall_e/models/__init__.py

vall_e/models/ar_nar.py

@@ -249,8 +249,16 @@ class AR_NAR(Base):
         use_lora=None,
         **sampling_kwargs,
     ):
-        device = phns_list[0].device
-        batch_size = len(phns_list)
+        # deduce batch_size
+        if phns_list:
+            device = phns_list[0].device
+            batch_size = len(phns_list)
+        elif text_list:
+            device = text_list[0].device
+            batch_size = len(text_list)
+        elif proms_list:
+            device = proms_list[0].device
+            batch_size = len(proms_list)

         if quant_levels is None:
             level = 0
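
One caveat with the deduction above: if phns_list, text_list, and proms_list are all None or empty, device and batch_size are left unbound and the first later use raises a NameError. A defensive variant (a sketch with a hypothetical helper name, not part of the commit) could fail loudly instead:

def deduce_batch(phns_list=None, text_list=None, proms_list=None):
    # return (device, batch_size) from the first non-empty list of tensors
    for candidate in (phns_list, text_list, proms_list):
        if candidate:
            return candidate[0].device, len(candidate)
    raise ValueError("no conditioning inputs given; cannot deduce device/batch_size")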
@@ -384,6 +392,7 @@ class AR_NAR(Base):
         # setup inputs
         inputs = super().inputs(
             phns_list=phns_list,
+            text_list=text_list,
             proms_list=proms_list,
             resps_list=input_resps_list,
             lang_list=lang_list,
@@ -400,7 +409,8 @@ class AR_NAR(Base):
         if cfg_strength > 0:
             null_inputs = super().inputs(
-                phns_list=null_text,
+                phns_list=null_text if phns_list is not None else None,
+                text_list=null_text if text_list is not None else None,
                 proms_list=null_prom,
                 resps_list=input_resps_list,
                 lang_list=lang_list,
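
These null inputs drive classifier-free guidance: the model runs once with the real conditioning and once with the null text/prompt, and the two logit sets are blended. A common formulation (an assumption here; the repo's own cfg helper may differ, e.g. in per-position masking):

import torch

def cfg_blend(cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    # cfg_strength 0 -> unconditional, 1 -> conditional, >1 extrapolates past it
    return null_logits + (cond_logits - null_logits) * cfg_strength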
@@ -472,6 +482,7 @@ class AR_NAR(Base):
         if len_list is not None:
             resps_list = self.forward_nar_masked(
                 phns_list=phns_list,
+                text_list=text_list,
                 proms_list=proms_list,
                 resps_list=resps_list,
                 task_list=task_list,
@@ -544,7 +555,8 @@ class AR_NAR(Base):
         if cfg_strength > 0:
             null_inputs = super().inputs(
-                phns_list=null_text,
+                phns_list=null_text if phns_list is not None else None,
+                text_list=null_text if text_list is not None else None,
                 proms_list=null_prom,
                 resps_list=prev_list,
                 lang_list=lang_list,
@@ -769,7 +781,8 @@ class AR_NAR(Base):
         if cfg_strength > 0:
             null_inputs = super().inputs(
-                phns_list=null_text,
+                phns_list=null_text if phns_list is not None else None,
+                text_list=null_text if text_list is not None else None,
                 proms_list=null_prom,
                 resps_list=resps_list,
                 lang_list=lang_list,
@@ -908,7 +921,7 @@ class AR_NAR(Base):
         batch_size = len(resps_list)

         # implicitly set for training
-        if training is None and phns_list is not None and resps_list is not None:
+        if training is None and (phns_list is not None or text_list is not None) and resps_list is not None:
             n_levels_set = {r.shape[-1] for r in resps_list}
             n_levels = next(iter(n_levels_set))
@@ -930,7 +943,7 @@ class AR_NAR(Base):
         )

         # is NAR
-        if (len_list is not None or resps_list is not None) and phns_list is not None:
+        if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
             return self.forward_nar(
                 task_list=task_list,

vall_e/models/ar_nar_v2.py

@@ -220,8 +220,16 @@ class AR_NAR_V2(Base_V2):
         use_lora=None,
         **sampling_kwargs,
     ):
-        device = phns_list[0].device
-        batch_size = len(phns_list)
+        # deduce batch_size
+        if phns_list:
+            device = phns_list[0].device
+            batch_size = len(phns_list)
+        elif text_list:
+            device = text_list[0].device
+            batch_size = len(text_list)
+        elif proms_list:
+            device = proms_list[0].device
+            batch_size = len(proms_list)

         level = 0
         if cfg.lora is not None:
@@ -298,6 +306,7 @@ class AR_NAR_V2(Base_V2):
         # setup inputs
         inputs = super().inputs(
             phns_list=phns_list,
+            text_list=text_list,
             proms_list=proms_list,
             resps_list=input_resps_list,
             lang_list=lang_list,
@@ -313,7 +322,8 @@ class AR_NAR_V2(Base_V2):
         logits = output.logits
         if cfg_strength > 0:
             null_inputs = super().inputs(
-                phns_list=null_text,
+                phns_list=null_text if phns_list is not None else None,
+                text_list=null_text if text_list is not None else None,
                 proms_list=null_prom,
                 resps_list=input_resps_list,
                 lang_list=lang_list,
@@ -507,7 +517,8 @@ class AR_NAR_V2(Base_V2):
         if cfg_strength > 0:
             null_inputs = super().inputs(
-                phns_list=null_text,
+                phns_list=null_text if phns_list is not None else None,
+                text_list=null_text if text_list is not None else None,
                 proms_list=null_prom,
                 resps_list=resps_list,
                 lang_list=lang_list,
@@ -615,7 +626,7 @@ class AR_NAR_V2(Base_V2):
         )

         # is NAR
-        if (len_list is not None or resps_list is not None) and phns_list is not None:
+        if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
             # to-do: verify this actually does return the input resps if theyre already filled
             """
             if resps_list is not None:

vall_e/models/base.py — Executable file → Normal file (0 lines changed)

vall_e/models/base_v2.py

@@ -765,7 +765,7 @@ class Base_V2(nn.Module):
             # needed, cringe
             if task_type == "len":
-                batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )
+                batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )

             x_list.append( _join( batch, self.sep ) )
@@ -1219,6 +1219,7 @@ class Base_V2(nn.Module):
         if tasks[0] == "len":
             # do duration prediction
             logits_aux = self.len_decoder( output.logits )
+            print( logits_aux[0].shape, logits_aux[0] )

             # it's more accurate this way
             logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
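
For reference, the surviving line reads the duration estimate off the final position of each per-sample logit tensor (aux_len goes unused in the expression itself). A shape-only sketch under an assumed (positions, classes) layout:

import torch

logits_aux = [torch.randn(120, 11)]  # one (positions, classes) tensor per batch item (assumed shapes)
duration_logits = [logit[..., -1, :1] for logit in logits_aux]
print(duration_logits[0].shape)  # torch.Size([1])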

vall_e/train.py — Executable file → Normal file (0 lines changed)

vall_e/webui.py

@@ -477,11 +477,9 @@ with ui:
             with gr.Row():
                 layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="How to split the text into utterances.", value="sentences")
                 layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
-            """
             with gr.Row():
                 layout["inference_tts"]["inputs"]["no-phonemize"] = gr.Checkbox(label="No Phonemize", info="Use raw text rather than phonemize the text as the input prompt.")
                 layout["inference_tts"]["inputs"]["play"] = gr.Checkbox(label="Auto Play", info="Auto play on generation (using sounddevice).")
-            """
         with gr.Tab("Sampler Settings"):
             with gr.Row():
                 layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Adjusts the probabilities in the AR/NAR-len. (0 to greedy* sample)")