fixes for non-phonemized text input
This commit is contained in:
parent
476d87d4aa
commit
ef0fd0c8ac
|
@ -384,6 +384,7 @@ class AR_NAR(Base):
|
|||
# setup inputs
|
||||
inputs = super().inputs(
|
||||
phns_list=phns_list,
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=input_resps_list,
|
||||
lang_list=lang_list,
|
||||
|
@ -400,7 +401,8 @@ class AR_NAR(Base):
|
|||
|
||||
if cfg_strength > 0:
|
||||
null_inputs = super().inputs(
|
||||
phns_list=null_text,
|
||||
phns_list=null_text if phns_list is not None else None,
|
||||
text_list=null_text if text_list is not None else None,
|
||||
proms_list=null_prom,
|
||||
resps_list=input_resps_list,
|
||||
lang_list=lang_list,
|
||||
|
@ -472,6 +474,7 @@ class AR_NAR(Base):
|
|||
if len_list is not None:
|
||||
resps_list = self.forward_nar_masked(
|
||||
phns_list=phns_list,
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
task_list=task_list,
|
||||
|
@ -529,6 +532,7 @@ class AR_NAR(Base):
|
|||
|
||||
inputs = self.inputs(
|
||||
phns_list=phns_list,
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=prev_list,
|
||||
lang_list=lang_list,
|
||||
|
@ -544,7 +548,8 @@ class AR_NAR(Base):
|
|||
|
||||
if cfg_strength > 0:
|
||||
null_inputs = super().inputs(
|
||||
phns_list=null_text,
|
||||
phns_list=null_text if phns_list is not None else None,
|
||||
text_list=null_text if text_list is not None else None,
|
||||
proms_list=null_prom,
|
||||
resps_list=prev_list,
|
||||
lang_list=lang_list,
|
||||
|
@ -769,7 +774,8 @@ class AR_NAR(Base):
|
|||
|
||||
if cfg_strength > 0:
|
||||
null_inputs = super().inputs(
|
||||
phns_list=null_text,
|
||||
phns_list=null_text if phns_list is not None else None,
|
||||
text_list=null_text if text_list is not None else None,
|
||||
proms_list=null_prom,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
|
|
|
@ -220,8 +220,12 @@ class AR_NAR_V2(Base_V2):
|
|||
use_lora=None,
|
||||
**sampling_kwargs,
|
||||
):
|
||||
device = phns_list[0].device
|
||||
batch_size = len(phns_list)
|
||||
if phns_list is not None:
|
||||
device = phns_list[0].device
|
||||
batch_size = len(phns_list)
|
||||
elif text_list is not None:
|
||||
device = text_list[0].device
|
||||
batch_size = len(text_list)
|
||||
|
||||
level = 0
|
||||
if cfg.lora is not None:
|
||||
|
@ -298,6 +302,7 @@ class AR_NAR_V2(Base_V2):
|
|||
# setup inputs
|
||||
inputs = super().inputs(
|
||||
phns_list=phns_list,
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=input_resps_list,
|
||||
lang_list=lang_list,
|
||||
|
@ -313,7 +318,8 @@ class AR_NAR_V2(Base_V2):
|
|||
logits = output.logits
|
||||
if cfg_strength > 0:
|
||||
null_inputs = super().inputs(
|
||||
phns_list=null_text,
|
||||
phns_list=null_text if phns_list is not None else None,
|
||||
text_list=null_text if text_list is not None else None,
|
||||
proms_list=null_prom,
|
||||
resps_list=input_resps_list,
|
||||
lang_list=lang_list,
|
||||
|
@ -507,7 +513,8 @@ class AR_NAR_V2(Base_V2):
|
|||
|
||||
if cfg_strength > 0:
|
||||
null_inputs = super().inputs(
|
||||
phns_list=null_text,
|
||||
phns_list=null_text if phns_list is not None else None,
|
||||
text_list=null_text if text_list is not None else None,
|
||||
proms_list=null_prom,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
|
@ -615,7 +622,7 @@ class AR_NAR_V2(Base_V2):
|
|||
)
|
||||
|
||||
# is NAR
|
||||
if (len_list is not None or resps_list is not None) and phns_list is not None:
|
||||
if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
|
||||
# to-do: verify this actually does return the input resps if theyre already filled
|
||||
"""
|
||||
if resps_list is not None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user