diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index f3716eb..d58ef26 100755 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -102,6 +102,7 @@ class AR(Base): text_list: list[Tensor], proms_list: list[Tensor], resps_list: list[Tensor] | None = None, + lang_list: list[Tensor] | None = None, max_steps: int = 1000, max_resp_context: int = -1, @@ -128,6 +129,7 @@ class AR(Base): proms_list=proms_list, resps_list=self._unsqueeze_list(resps_list), targ_list=resps_list, + lang_list=lang_list, quant_levels=None, ) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 4bd2679..e9cefab 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -134,6 +134,7 @@ class AR_NAR(Base): proms_list=proms_list, resps_list=resps_list, targ_list=targ_list, + lang_list=lang_list, quant_levels=quant_levels, ) # is NAR @@ -153,6 +154,7 @@ class AR_NAR(Base): text_list=text_list, proms_list=proms_list, resps_list=prev_list, + lang_list=lang_list, quant_levels=quant_levels, ) @@ -204,7 +206,7 @@ class AR_NAR(Base): text_list=text_list, proms_list=proms_list, resps_list=resps_list, - + lang_list=lang_list, state=recurrent_state ) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 9e1ed00..c90ddf9 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -304,7 +304,7 @@ class Base(nn.Module): batch_size = len(text_list) if self.langs_emb is None: - langs_list = None + lang_list = None x_list = self._samplewise_merge_tensors( self.text_emb(text_list), diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 91d8d34..bb18f3a 100755 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -74,6 +74,7 @@ class NAR(Base): text_list: list[Tensor], proms_list: list[Tensor], resps_list: list[Tensor], + lang_list: list[Tensor] | None = None, max_levels: int = 0, sampling_temperature: float = 0.2, sampling_min_temperature: float = -1.0, @@ -118,6 +119,7 @@ class NAR(Base): proms_list=proms_list, resps_list=prev_list, targ_list=targ_list, + lang_list=lang_list, quant_levels=quant_levels, ) @@ -139,6 +141,7 @@ class NAR(Base): text_list=text_list, proms_list=proms_list, resps_list=prev_list, + lang_list=lang_list, quant_levels=quant_levels, ) diff --git a/vall_e/train.py b/vall_e/train.py index 91e63cd..e7821a5 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -26,7 +26,8 @@ def train_feeder(engine, batch): engine( text_list=batch["text"], proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level - resps_list=batch["resps"] + resps_list=batch["resps"], + lang_list=batch["lang"], ) losses = engine.gather_attribute("loss")