diff --git a/vall_e/vall_e/ar.py b/vall_e/vall_e/ar.py index bda1bd4..4409745 100644 --- a/vall_e/vall_e/ar.py +++ b/vall_e/vall_e/ar.py @@ -23,29 +23,33 @@ class AR(Base): indices = (l == self.stop_token).nonzero() if len(indices) == 0: return l - return l[: indices[0].item()] + return l[: indices.min().item()] def forward( self, text_list: list[Tensor], proms_list: list[Tensor], - resp_list: list[Tensor], + resp_list: list[Tensor] | None = None, + max_steps: int = 1000, ): - return super().forward( - text_list, - proms_list, - resp_list, - resp_list, - quant_level=0, - shift_targ_list=True, - return_all_resp=False, - ) + if resp_list is not None: + return super().forward( + text_list, + proms_list, + resp_list, + resp_list, + quant_level=0, + shift_targ_list=True, + return_all_resp=False, + ) + else: + return self._generate(text_list, proms_list, max_steps) - def generate( + def _generate( self, text_list: list[Tensor], proms_list: list[Tensor], - max_steps: int = 1000, + max_steps: int, ): device = text_list[0].device resp_list: list[Tensor] = [ @@ -93,11 +97,7 @@ def example_usage(): qnt.to(device), ] - out = model.generate( - text_list, - proms_list, - max_steps=200, - ) + out = model(text_list, proms_list, max_steps=200) print(out) @@ -114,7 +114,7 @@ def example_usage(): if i % 20 == 0: print(f"iter={i}, {losses}.") - out = model.generate(text_list, proms_list, max_steps=200) + out = model(text_list, proms_list, max_steps=200) print(qnt) print(out)