fixes for non-phonemized text input
This commit is contained in:
parent
476d87d4aa
commit
ef0fd0c8ac
|
@ -384,6 +384,7 @@ class AR_NAR(Base):
|
||||||
# setup inputs
|
# setup inputs
|
||||||
inputs = super().inputs(
|
inputs = super().inputs(
|
||||||
phns_list=phns_list,
|
phns_list=phns_list,
|
||||||
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
resps_list=input_resps_list,
|
resps_list=input_resps_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
|
@ -400,7 +401,8 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
if cfg_strength > 0:
|
if cfg_strength > 0:
|
||||||
null_inputs = super().inputs(
|
null_inputs = super().inputs(
|
||||||
phns_list=null_text,
|
phns_list=null_text if phns_list is not None else None,
|
||||||
|
text_list=null_text if text_list is not None else None,
|
||||||
proms_list=null_prom,
|
proms_list=null_prom,
|
||||||
resps_list=input_resps_list,
|
resps_list=input_resps_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
|
@ -472,6 +474,7 @@ class AR_NAR(Base):
|
||||||
if len_list is not None:
|
if len_list is not None:
|
||||||
resps_list = self.forward_nar_masked(
|
resps_list = self.forward_nar_masked(
|
||||||
phns_list=phns_list,
|
phns_list=phns_list,
|
||||||
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
resps_list=resps_list,
|
resps_list=resps_list,
|
||||||
task_list=task_list,
|
task_list=task_list,
|
||||||
|
@ -529,6 +532,7 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
inputs = self.inputs(
|
inputs = self.inputs(
|
||||||
phns_list=phns_list,
|
phns_list=phns_list,
|
||||||
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
resps_list=prev_list,
|
resps_list=prev_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
|
@ -544,7 +548,8 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
if cfg_strength > 0:
|
if cfg_strength > 0:
|
||||||
null_inputs = super().inputs(
|
null_inputs = super().inputs(
|
||||||
phns_list=null_text,
|
phns_list=null_text if phns_list is not None else None,
|
||||||
|
text_list=null_text if text_list is not None else None,
|
||||||
proms_list=null_prom,
|
proms_list=null_prom,
|
||||||
resps_list=prev_list,
|
resps_list=prev_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
|
@ -769,7 +774,8 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
if cfg_strength > 0:
|
if cfg_strength > 0:
|
||||||
null_inputs = super().inputs(
|
null_inputs = super().inputs(
|
||||||
phns_list=null_text,
|
phns_list=null_text if phns_list is not None else None,
|
||||||
|
text_list=null_text if text_list is not None else None,
|
||||||
proms_list=null_prom,
|
proms_list=null_prom,
|
||||||
resps_list=resps_list,
|
resps_list=resps_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
|
|
|
@ -220,8 +220,12 @@ class AR_NAR_V2(Base_V2):
|
||||||
use_lora=None,
|
use_lora=None,
|
||||||
**sampling_kwargs,
|
**sampling_kwargs,
|
||||||
):
|
):
|
||||||
device = phns_list[0].device
|
if phns_list is not None:
|
||||||
batch_size = len(phns_list)
|
device = phns_list[0].device
|
||||||
|
batch_size = len(phns_list)
|
||||||
|
elif text_list is not None:
|
||||||
|
device = text_list[0].device
|
||||||
|
batch_size = len(text_list)
|
||||||
|
|
||||||
level = 0
|
level = 0
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
|
@ -298,6 +302,7 @@ class AR_NAR_V2(Base_V2):
|
||||||
# setup inputs
|
# setup inputs
|
||||||
inputs = super().inputs(
|
inputs = super().inputs(
|
||||||
phns_list=phns_list,
|
phns_list=phns_list,
|
||||||
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
resps_list=input_resps_list,
|
resps_list=input_resps_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
|
@ -313,7 +318,8 @@ class AR_NAR_V2(Base_V2):
|
||||||
logits = output.logits
|
logits = output.logits
|
||||||
if cfg_strength > 0:
|
if cfg_strength > 0:
|
||||||
null_inputs = super().inputs(
|
null_inputs = super().inputs(
|
||||||
phns_list=null_text,
|
phns_list=null_text if phns_list is not None else None,
|
||||||
|
text_list=null_text if text_list is not None else None,
|
||||||
proms_list=null_prom,
|
proms_list=null_prom,
|
||||||
resps_list=input_resps_list,
|
resps_list=input_resps_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
|
@ -507,7 +513,8 @@ class AR_NAR_V2(Base_V2):
|
||||||
|
|
||||||
if cfg_strength > 0:
|
if cfg_strength > 0:
|
||||||
null_inputs = super().inputs(
|
null_inputs = super().inputs(
|
||||||
phns_list=null_text,
|
phns_list=null_text if phns_list is not None else None,
|
||||||
|
text_list=null_text if text_list is not None else None,
|
||||||
proms_list=null_prom,
|
proms_list=null_prom,
|
||||||
resps_list=resps_list,
|
resps_list=resps_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
|
@ -615,7 +622,7 @@ class AR_NAR_V2(Base_V2):
|
||||||
)
|
)
|
||||||
|
|
||||||
# is NAR
|
# is NAR
|
||||||
if (len_list is not None or resps_list is not None) and phns_list is not None:
|
if (len_list is not None or resps_list is not None) and (phns_list is not None or text_list is not None):
|
||||||
# to-do: verify this actually does return the input resps if theyre already filled
|
# to-do: verify this actually does return the input resps if theyre already filled
|
||||||
"""
|
"""
|
||||||
if resps_list is not None:
|
if resps_list is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user