fixes for non-phonemized text input

This commit is contained in:
mrq 2025-03-25 22:02:14 -05:00
parent 476d87d4aa
commit ef0fd0c8ac
2 changed files with 21 additions and 8 deletions

View File

@ -384,6 +384,7 @@ class AR_NAR(Base):
# setup inputs # setup inputs
inputs = super().inputs( inputs = super().inputs(
phns_list=phns_list, phns_list=phns_list,
text_list=text_list,
proms_list=proms_list, proms_list=proms_list,
resps_list=input_resps_list, resps_list=input_resps_list,
lang_list=lang_list, lang_list=lang_list,
@ -400,7 +401,8 @@ class AR_NAR(Base):
if cfg_strength > 0: if cfg_strength > 0:
null_inputs = super().inputs( null_inputs = super().inputs(
phns_list=null_text, phns_list=null_text if phns_list is not None else None,
text_list=null_text if text_list is not None else None,
proms_list=null_prom, proms_list=null_prom,
resps_list=input_resps_list, resps_list=input_resps_list,
lang_list=lang_list, lang_list=lang_list,
@ -472,6 +474,7 @@ class AR_NAR(Base):
if len_list is not None: if len_list is not None:
resps_list = self.forward_nar_masked( resps_list = self.forward_nar_masked(
phns_list=phns_list, phns_list=phns_list,
text_list=text_list,
proms_list=proms_list, proms_list=proms_list,
resps_list=resps_list, resps_list=resps_list,
task_list=task_list, task_list=task_list,
@ -529,6 +532,7 @@ class AR_NAR(Base):
inputs = self.inputs( inputs = self.inputs(
phns_list=phns_list, phns_list=phns_list,
text_list=text_list,
proms_list=proms_list, proms_list=proms_list,
resps_list=prev_list, resps_list=prev_list,
lang_list=lang_list, lang_list=lang_list,
@ -544,7 +548,8 @@ class AR_NAR(Base):
if cfg_strength > 0: if cfg_strength > 0:
null_inputs = super().inputs( null_inputs = super().inputs(
phns_list=null_text, phns_list=null_text if phns_list is not None else None,
text_list=null_text if text_list is not None else None,
proms_list=null_prom, proms_list=null_prom,
resps_list=prev_list, resps_list=prev_list,
lang_list=lang_list, lang_list=lang_list,
@ -769,7 +774,8 @@ class AR_NAR(Base):
if cfg_strength > 0: if cfg_strength > 0:
null_inputs = super().inputs( null_inputs = super().inputs(
phns_list=null_text, phns_list=null_text if phns_list is not None else None,
text_list=null_text if text_list is not None else None,
proms_list=null_prom, proms_list=null_prom,
resps_list=resps_list, resps_list=resps_list,
lang_list=lang_list, lang_list=lang_list,

View File

@ -220,8 +220,12 @@ class AR_NAR_V2(Base_V2):
use_lora=None, use_lora=None,
**sampling_kwargs, **sampling_kwargs,
): ):
device = phns_list[0].device if phns_list is not None:
batch_size = len(phns_list) device = phns_list[0].device
batch_size = len(phns_list)
elif text_list is not None:
device = text_list[0].device
batch_size = len(text_list)
level = 0 level = 0
if cfg.lora is not None: if cfg.lora is not None:
@ -298,6 +302,7 @@ class AR_NAR_V2(Base_V2):
# setup inputs # setup inputs
inputs = super().inputs( inputs = super().inputs(
phns_list=phns_list, phns_list=phns_list,
text_list=text_list,
proms_list=proms_list, proms_list=proms_list,
resps_list=input_resps_list, resps_list=input_resps_list,
lang_list=lang_list, lang_list=lang_list,
@ -313,7 +318,8 @@ class AR_NAR_V2(Base_V2):
logits = output.logits logits = output.logits
if cfg_strength > 0: if cfg_strength > 0:
null_inputs = super().inputs( null_inputs = super().inputs(
phns_list=null_text, phns_list=null_text if phns_list is not None else None,
text_list=null_text if text_list is not None else None,
proms_list=null_prom, proms_list=null_prom,
resps_list=input_resps_list, resps_list=input_resps_list,
lang_list=lang_list, lang_list=lang_list,
@ -507,7 +513,8 @@ class AR_NAR_V2(Base_V2):
if cfg_strength > 0: if cfg_strength > 0:
null_inputs = super().inputs( null_inputs = super().inputs(
phns_list=null_text, phns_list=null_text if phns_list is not None else None,
text_list=null_text if text_list is not None else None,
proms_list=null_prom, proms_list=null_prom,
resps_list=resps_list, resps_list=resps_list,
lang_list=lang_list, lang_list=lang_list,
@ -615,7 +622,7 @@ class AR_NAR_V2(Base_V2):
) )
# is NAR # is NAR
if (len_list is not None or resps_list is not None) and phns_list is not None: if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
# to-do: verify this actually does return the input resps if they're already filled # to-do: verify this actually does return the input resps if they're already filled
""" """
if resps_list is not None: if resps_list is not None: