nothing could go wrong part 2 (reverted and rewrote commits since there was a nasty regression)
commit 8641c87611 (parent aa8b32d97e)

@@ -145,6 +145,9 @@ These settings should be avoided:

To be evaluated thoroughly.
* The smaller model seems to have hit its capacity limit, while the larger model is slowly improving (although objective metrics have not been recorded).
* Inference seems pretty quick, even for the large model.
* The smaller model seems small enough for CPU-only inferencing.
* Despite its poor zero-shot performance, it could be perfectly fine for finetuning.

At a glance, compared to the prior model setup, this implementation lets the model represent speech better, as it is able to see the entire signal and account for it in its latent space, rather than only specific levels of it.

|
0
vall_e/__init__.py
Executable file → Normal file
0
vall_e/__init__.py
Executable file → Normal file
0
vall_e/__main__.py
Executable file → Normal file
0
vall_e/__main__.py
Executable file → Normal file
0
vall_e/config.py
Executable file → Normal file
0
vall_e/config.py
Executable file → Normal file
0
vall_e/data.py
Executable file → Normal file
0
vall_e/data.py
Executable file → Normal file
0
vall_e/export.py
Executable file → Normal file
0
vall_e/export.py
Executable file → Normal file
0
vall_e/inference.py
Executable file → Normal file
0
vall_e/inference.py
Executable file → Normal file
0
vall_e/models/__init__.py
Executable file → Normal file
0
vall_e/models/__init__.py
Executable file → Normal file
|
@@ -249,8 +249,16 @@ class AR_NAR(Base):
        use_lora=None,
        **sampling_kwargs,
    ):
-       device = phns_list[0].device
-       batch_size = len(phns_list)
+       # deduce batch_size
+       if phns_list:
+           device = phns_list[0].device
+           batch_size = len(phns_list)
+       elif text_list:
+           device = text_list[0].device
+           batch_size = len(text_list)
+       elif proms_list:
+           device = proms_list[0].device
+           batch_size = len(proms_list)

        if quant_levels is None:
            level = 0

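For reference, the deduction added above just reads the device and batch size off the first non-empty input list. A minimal standalone sketch of the same pattern (the helper name and the toy tensors are illustrative, not part of the actual AR_NAR signature):

    import torch

    def deduce_batch(phns_list=None, text_list=None, proms_list=None):
        # pick the first non-empty list of tensors and take device/batch size from it
        for candidate in (phns_list, text_list, proms_list):
            if candidate:
                return candidate[0].device, len(candidate)
        raise ValueError("no input list provided")

    # toy usage: only text_list is given, so device and batch size come from it
    text_list = [torch.zeros(4, dtype=torch.long), torch.zeros(6, dtype=torch.long)]
    device, batch_size = deduce_batch(text_list=text_list)
    print(device, batch_size)  # cpu 2
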
@@ -384,6 +392,7 @@ class AR_NAR(Base):
        # setup inputs
        inputs = super().inputs(
            phns_list=phns_list,
+           text_list=text_list,
            proms_list=proms_list,
            resps_list=input_resps_list,
            lang_list=lang_list,

@@ -400,7 +409,8 @@ class AR_NAR(Base):

        if cfg_strength > 0:
            null_inputs = super().inputs(
-               phns_list=null_text,
+               phns_list=null_text if phns_list is not None else None,
+               text_list=null_text if text_list is not None else None,
                proms_list=null_prom,
                resps_list=input_resps_list,
                lang_list=lang_list,

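The null inputs above form the unconditioned branch for classifier-free guidance, now built from whichever of phns_list/text_list is actually in use. As a hedged sketch of the general idea only (this hunk does not show the exact merge the repo applies to the two logit sets), conditioned and null logits are typically blended with the guidance strength:

    import torch

    def apply_cfg(cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float) -> torch.Tensor:
        # generic classifier-free guidance blend: push the prediction away from the
        # unconditioned output by cfg_strength (illustrative formula, not necessarily the repo's)
        return null_logits + (cond_logits - null_logits) * cfg_strength

    cond = torch.randn(2, 10)
    null = torch.randn(2, 10)
    print(apply_cfg(cond, null, cfg_strength=1.5).shape)  # torch.Size([2, 10])
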
@@ -472,6 +482,7 @@ class AR_NAR(Base):
        if len_list is not None:
            resps_list = self.forward_nar_masked(
                phns_list=phns_list,
+               text_list=text_list,
                proms_list=proms_list,
                resps_list=resps_list,
                task_list=task_list,

@@ -544,7 +555,8 @@ class AR_NAR(Base):

        if cfg_strength > 0:
            null_inputs = super().inputs(
-               phns_list=null_text,
+               phns_list=null_text if phns_list is not None else None,
+               text_list=null_text if text_list is not None else None,
                proms_list=null_prom,
                resps_list=prev_list,
                lang_list=lang_list,

@@ -769,7 +781,8 @@ class AR_NAR(Base):

        if cfg_strength > 0:
            null_inputs = super().inputs(
-               phns_list=null_text,
+               phns_list=null_text if phns_list is not None else None,
+               text_list=null_text if text_list is not None else None,
                proms_list=null_prom,
                resps_list=resps_list,
                lang_list=lang_list,

@@ -908,7 +921,7 @@ class AR_NAR(Base):
        batch_size = len(resps_list)

        # implicitly set for training
-       if training is None and phns_list is not None and resps_list is not None:
+       if training is None and (phns_list is not None or text_list is not None) and resps_list is not None:
            n_levels_set = {r.shape[-1] for r in resps_list}
            n_levels = next(iter(n_levels_set))

@@ -930,7 +943,7 @@ class AR_NAR(Base):
        )

        # is NAR
-       if (len_list is not None or resps_list is not None) and phns_list is not None:
+       if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
            return self.forward_nar(
                task_list=task_list,

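The widened check above lets either phonemized input (phns_list) or raw text (text_list) route a request into the NAR path. A tiny, hypothetical sketch of the same dispatch shape (not the real forward signature; the fallback branch is only a placeholder):

    def dispatch(phns_list=None, text_list=None, resps_list=None, len_list=None):
        # mirrors the updated condition: any text-like conditioning now qualifies for the NAR path
        if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
            return "forward_nar"
        # in the real model, other branches (e.g. the AR path) handle the remaining cases
        return "other"

    print(dispatch(text_list=["hello"], len_list=[4]))  # forward_nar
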
@@ -220,8 +220,16 @@ class AR_NAR_V2(Base_V2):
        use_lora=None,
        **sampling_kwargs,
    ):
-       device = phns_list[0].device
-       batch_size = len(phns_list)
+       # deduce batch_size
+       if phns_list:
+           device = phns_list[0].device
+           batch_size = len(phns_list)
+       elif text_list:
+           device = text_list[0].device
+           batch_size = len(text_list)
+       elif proms_list:
+           device = proms_list[0].device
+           batch_size = len(proms_list)

        level = 0
        if cfg.lora is not None:

@@ -298,6 +306,7 @@ class AR_NAR_V2(Base_V2):
        # setup inputs
        inputs = super().inputs(
            phns_list=phns_list,
+           text_list=text_list,
            proms_list=proms_list,
            resps_list=input_resps_list,
            lang_list=lang_list,

@@ -313,7 +322,8 @@ class AR_NAR_V2(Base_V2):
        logits = output.logits
        if cfg_strength > 0:
            null_inputs = super().inputs(
-               phns_list=null_text,
+               phns_list=null_text if phns_list is not None else None,
+               text_list=null_text if text_list is not None else None,
                proms_list=null_prom,
                resps_list=input_resps_list,
                lang_list=lang_list,

@@ -507,7 +517,8 @@ class AR_NAR_V2(Base_V2):

        if cfg_strength > 0:
            null_inputs = super().inputs(
-               phns_list=null_text,
+               phns_list=null_text if phns_list is not None else None,
+               text_list=null_text if text_list is not None else None,
                proms_list=null_prom,
                resps_list=resps_list,
                lang_list=lang_list,

@@ -615,7 +626,7 @@ class AR_NAR_V2(Base_V2):
        )

        # is NAR
-       if (len_list is not None or resps_list is not None) and phns_list is not None:
+       if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
            # to-do: verify this actually does return the input resps if theyre already filled
            """
            if resps_list is not None:

vall_e/models/base.py: Executable file → Normal file (0 lines changed)

@@ -765,7 +765,7 @@ class Base_V2(nn.Module):

        # needed, cringe
        if task_type == "len":
-           batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )
+           batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )

        x_list.append( _join( batch, self.sep ) )

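For context on the fix above: `self.sep[None]` adds a leading dimension so the separator embedding can be concatenated onto the end of the batch entry, and the "len" task now appends it twice. A toy illustration of that indexing and concatenation (shapes and the dummy `sep` embedding are made up):

    import torch

    sep = torch.zeros(8)             # dummy 1-D separator embedding, dim 8
    batch_entry = torch.randn(5, 8)  # dummy sequence of 5 embeddings

    # sep[None] has shape (1, 8), so it concatenates along the sequence axis;
    # appending it twice mirrors the new behaviour for the "len" task
    extended = torch.cat([batch_entry, sep[None], sep[None]])
    print(extended.shape)  # torch.Size([7, 8])
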
@@ -1219,6 +1219,7 @@ class Base_V2(nn.Module):
        if tasks[0] == "len":
            # do duration prediction
            logits_aux = self.len_decoder( output.logits )
+           print( logits_aux[0].shape, logits_aux[0] )
            # it's more accurate this way
            logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]

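The `logit[..., -1, :1]` slice kept above takes only the final timestep's first logit from each duration-prediction output. A small illustration of what that slice does on a dummy tensor:

    import torch

    logit = torch.randn(12, 11)  # dummy (timesteps, classes) output for one batch item
    last = logit[..., -1, :1]    # last timestep, first class only
    print(last.shape)            # torch.Size([1])
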
vall_e/train.py: Executable file → Normal file (0 lines changed)

@@ -477,11 +477,9 @@ with ui:
        with gr.Row():
            layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="How to split the text into utterances.", value="sentences")
            layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
-       """
        with gr.Row():
            layout["inference_tts"]["inputs"]["no-phonemize"] = gr.Checkbox(label="No Phonemize", info="Use raw text rather than phonemize the text as the input prompt.")
            layout["inference_tts"]["inputs"]["play"] = gr.Checkbox(label="Auto Play", info="Auto play on generation (using sounddevice).")
-       """
        with gr.Tab("Sampler Settings"):
            with gr.Row():
                layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Adjusts the probabilities in the AR/NAR-len. (0 to greedy* sample)")

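Removing the two triple-quote lines un-comments that second Row, exposing the "No Phonemize" and "Auto Play" controls. A minimal, hypothetical Gradio sketch of a row like it, outside the webui's real `layout` dict:

    import gradio as gr

    with gr.Blocks() as demo:
        with gr.Row():
            # labels and help text mirror the diff; wiring into the actual inference pipeline is omitted
            no_phonemize = gr.Checkbox(label="No Phonemize", info="Use raw text rather than phonemize the text as the input prompt.")
            auto_play = gr.Checkbox(label="Auto Play", info="Auto play on generation (using sounddevice).")

    if __name__ == "__main__":
        demo.launch()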