diff --git a/docs/models_v2.md b/docs/models_v2.md
index 91eabbb..7ab3996 100644
--- a/docs/models_v2.md
+++ b/docs/models_v2.md
@@ -145,6 +145,9 @@ These settings should be avoided:
 To be evaluated thoroughly.
 
 * The smaller model seems to have hit its capacity limit, while the larger model is slowly improving (although objective metrics are not noted).
+* Inference seems quite fast, even for the large model.
+* The smaller model seems small enough for CPU-only inference.
+  * Despite its poor zero-shot performance, it could be perfectly fine for finetuning.
 
 At a glance, compared to the prior model setup, this implementation allows for the model to better represent speech as it's able to see the entire signal and account for it in its latent space, rather than only specific levels of it.
 
diff --git a/vall_e/__init__.py b/vall_e/__init__.py
old mode 100755
new mode 100644
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
old mode 100755
new mode 100644
diff --git a/vall_e/config.py b/vall_e/config.py
old mode 100755
new mode 100644
diff --git a/vall_e/data.py b/vall_e/data.py
old mode 100755
new mode 100644
diff --git a/vall_e/export.py b/vall_e/export.py
old mode 100755
new mode 100644
diff --git a/vall_e/inference.py b/vall_e/inference.py
old mode 100755
new mode 100644
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
old mode 100755
new mode 100644
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 4aac53a..fcc72a9 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -249,8 +249,16 @@ class AR_NAR(Base):
 		use_lora=None,
 		**sampling_kwargs,
 	):
-		device = phns_list[0].device
-		batch_size = len(phns_list)
+		# deduce batch_size and device from whichever input list is provided
+		if phns_list:
+			device = phns_list[0].device
+			batch_size = len(phns_list)
+		elif text_list:
+			device = text_list[0].device
+			batch_size = len(text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
 
 		if quant_levels is None:
 			level = 0
@@ -384,6 +392,7 @@ class AR_NAR(Base):
 		# setup inputs
 		inputs = super().inputs(
 			phns_list=phns_list,
+			text_list=text_list,
 			proms_list=proms_list,
 			resps_list=input_resps_list,
 			lang_list=lang_list,
@@ -400,7 +409,8 @@ class AR_NAR(Base):
 
 		if cfg_strength > 0:
 			null_inputs = super().inputs(
-				phns_list=null_text,
+				phns_list=null_text if phns_list is not None else None,
+				text_list=null_text if text_list is not None else None,
 				proms_list=null_prom,
 				resps_list=input_resps_list,
 				lang_list=lang_list,
@@ -472,6 +482,7 @@ class AR_NAR(Base):
 		if len_list is not None:
 			resps_list = self.forward_nar_masked(
 				phns_list=phns_list,
+				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
 				task_list=task_list,
@@ -544,7 +555,8 @@ class AR_NAR(Base):
 
 		if cfg_strength > 0:
 			null_inputs = super().inputs(
-				phns_list=null_text,
+				phns_list=null_text if phns_list is not None else None,
+				text_list=null_text if text_list is not None else None,
 				proms_list=null_prom,
 				resps_list=prev_list,
 				lang_list=lang_list,
@@ -769,7 +781,8 @@ class AR_NAR(Base):
 
 		if cfg_strength > 0:
 			null_inputs = super().inputs(
-				phns_list=null_text,
+				phns_list=null_text if phns_list is not None else None,
+				text_list=null_text if text_list is not None else None,
 				proms_list=null_prom,
 				resps_list=resps_list,
 				lang_list=lang_list,
@@ -908,7 +921,7 @@ class AR_NAR(Base):
 			batch_size = len(resps_list)
 
 		# implicitly set for training
-		if training is None and phns_list is not None and resps_list is not None:
+		if training is None and (phns_list is not None or text_list is not None) and resps_list is not None:
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))
 
@@ -930,7 +943,7 @@ class AR_NAR(Base):
 		)
 
 		# is NAR
-		if (len_list is not None or resps_list is not None) and phns_list is not None:
+		if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
 			return self.forward_nar(
 				task_list=task_list,
diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index fc494b8..c0549f0 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -220,8 +220,16 @@ class AR_NAR_V2(Base_V2):
 		use_lora=None,
 		**sampling_kwargs,
 	):
-		device = phns_list[0].device
-		batch_size = len(phns_list)
+		# deduce batch_size and device from whichever input list is provided
+		if phns_list:
+			device = phns_list[0].device
+			batch_size = len(phns_list)
+		elif text_list:
+			device = text_list[0].device
+			batch_size = len(text_list)
+		elif proms_list:
+			device = proms_list[0].device
+			batch_size = len(proms_list)
 
 		level = 0
 		if cfg.lora is not None:
@@ -298,6 +306,7 @@ class AR_NAR_V2(Base_V2):
 		# setup inputs
 		inputs = super().inputs(
 			phns_list=phns_list,
+			text_list=text_list,
 			proms_list=proms_list,
 			resps_list=input_resps_list,
 			lang_list=lang_list,
@@ -313,7 +322,8 @@ class AR_NAR_V2(Base_V2):
 		logits = output.logits
 		if cfg_strength > 0:
 			null_inputs = super().inputs(
-				phns_list=null_text,
+				phns_list=null_text if phns_list is not None else None,
+				text_list=null_text if text_list is not None else None,
 				proms_list=null_prom,
 				resps_list=input_resps_list,
 				lang_list=lang_list,
@@ -507,7 +517,8 @@ class AR_NAR_V2(Base_V2):
 
 		if cfg_strength > 0:
 			null_inputs = super().inputs(
-				phns_list=null_text,
+				phns_list=null_text if phns_list is not None else None,
+				text_list=null_text if text_list is not None else None,
 				proms_list=null_prom,
 				resps_list=resps_list,
 				lang_list=lang_list,
@@ -615,7 +626,7 @@ class AR_NAR_V2(Base_V2):
 		)
 
 		# is NAR
-		if (len_list is not None or resps_list is not None) and phns_list is not None:
+		if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
 			# to-do: verify this actually does return the input resps if theyre already filled
 			"""
 			if resps_list is not None:
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
old mode 100755
new mode 100644
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 258205f..b7ddf57 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -765,7 +765,7 @@ class Base_V2(nn.Module):
 
 			# needed, cringe
 			if task_type == "len":
-				batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )
+				batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
 
 			x_list.append( _join( batch, self.sep ) )
 
@@ -1219,6 +1219,7 @@ class Base_V2(nn.Module):
 		if tasks[0] == "len":
 			# do duration prediction
 			logits_aux = self.len_decoder( output.logits )
+			print( logits_aux[0].shape, logits_aux[0] )
 			# it's more accurate this way
 			logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
 
diff --git a/vall_e/train.py b/vall_e/train.py
old mode 100755
new mode 100644
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 8b730e8..f2694f6 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -477,11 +477,9 @@ with ui:
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="How to split the text into utterances.", value="sentences")
 						layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
-					"""
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["no-phonemize"] = gr.Checkbox(label="No Phonemize", info="Use raw text rather than phonemize the text as the input prompt.")
 						layout["inference_tts"]["inputs"]["play"] = gr.Checkbox(label="Auto Play", info="Auto play on generation (using sounddevice).")
-					"""
 				with gr.Tab("Sampler Settings"):
 					with gr.Row():
 						layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Adjusts the probabilities in the AR/NAR-len. (0 to greedy* sample)")