From ef0fd0c8acac7ae94e4575beaa4db7c397b9076b Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 25 Mar 2025 22:02:14 -0500 Subject: [PATCH] fixes for non-phonemized text input --- vall_e/models/ar_nar.py | 12 +++++++++--- vall_e/models/ar_nar_v2.py | 17 ++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 4aac53a..ad77598 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -384,6 +384,7 @@ class AR_NAR(Base): # setup inputs inputs = super().inputs( phns_list=phns_list, + text_list=text_list, proms_list=proms_list, resps_list=input_resps_list, lang_list=lang_list, @@ -400,7 +401,8 @@ class AR_NAR(Base): if cfg_strength > 0: null_inputs = super().inputs( - phns_list=null_text, + phns_list=null_text if phns_list is not None else None, + text_list=null_text if text_list is not None else None, proms_list=null_prom, resps_list=input_resps_list, lang_list=lang_list, @@ -472,6 +474,7 @@ class AR_NAR(Base): if len_list is not None: resps_list = self.forward_nar_masked( phns_list=phns_list, + text_list=text_list, proms_list=proms_list, resps_list=resps_list, task_list=task_list, @@ -529,6 +532,7 @@ class AR_NAR(Base): inputs = self.inputs( phns_list=phns_list, + text_list=text_list, proms_list=proms_list, resps_list=prev_list, lang_list=lang_list, @@ -544,7 +548,8 @@ class AR_NAR(Base): if cfg_strength > 0: null_inputs = super().inputs( - phns_list=null_text, + phns_list=null_text if phns_list is not None else None, + text_list=null_text if text_list is not None else None, proms_list=null_prom, resps_list=prev_list, lang_list=lang_list, @@ -769,7 +774,8 @@ class AR_NAR(Base): if cfg_strength > 0: null_inputs = super().inputs( - phns_list=null_text, + phns_list=null_text if phns_list is not None else None, + text_list=null_text if text_list is not None else None, proms_list=null_prom, resps_list=resps_list, lang_list=lang_list, diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py index fc494b8..27a0a03 100644 --- a/vall_e/models/ar_nar_v2.py +++ b/vall_e/models/ar_nar_v2.py @@ -220,8 +220,12 @@ class AR_NAR_V2(Base_V2): use_lora=None, **sampling_kwargs, ): - device = phns_list[0].device - batch_size = len(phns_list) + if phns_list is not None: + device = phns_list[0].device + batch_size = len(phns_list) + elif text_list is not None: + device = text_list[0].device + batch_size = len(text_list) level = 0 if cfg.lora is not None: @@ -298,6 +302,7 @@ class AR_NAR_V2(Base_V2): # setup inputs inputs = super().inputs( phns_list=phns_list, + text_list=text_list, proms_list=proms_list, resps_list=input_resps_list, lang_list=lang_list, @@ -313,7 +318,8 @@ class AR_NAR_V2(Base_V2): logits = output.logits if cfg_strength > 0: null_inputs = super().inputs( - phns_list=null_text, + phns_list=null_text if phns_list is not None else None, + text_list=null_text if text_list is not None else None, proms_list=null_prom, resps_list=input_resps_list, lang_list=lang_list, @@ -507,7 +513,8 @@ class AR_NAR_V2(Base_V2): if cfg_strength > 0: null_inputs = super().inputs( - phns_list=null_text, + phns_list=null_text if phns_list is not None else None, + text_list=null_text if text_list is not None else None, proms_list=null_prom, resps_list=resps_list, lang_list=lang_list, @@ -615,7 +622,7 @@ class AR_NAR_V2(Base_V2): ) # is NAR - if (len_list is not None or resps_list is not None) and phns_list is not None: + if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None): # to-do: verify this actually does return the input resps if theyre already filled """ if resps_list is not None: