diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index 46d0661..7e209db 100755 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -117,10 +117,11 @@ class AR(Base): device = text_list[0].device batch_size = len(text_list) - resps_list: list[Tensor] = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ] + sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ] stopped = torch.zeros(batch_size, device=device).bool() state = {} if cfg.inference.recurrent_forward else None + sampling_beam_width_use_logs = True scores = [ 1.0 ] * sampling_beam_width @@ -129,12 +130,12 @@ class AR(Base): # get next in sequence for n in trange(max_steps // max(1, self.recurrent_chunk_size)): - + resps_list = self._unsqueeze_list(sequence_list) logits = super().forward( text_list=text_list, proms_list=proms_list, - resps_list=self._unsqueeze_list(resps_list), - quant_levels=None, + resps_list=resps_list, + state=state ) @@ -173,8 +174,7 @@ class AR(Base): for i, ri in enumerate(r): if self.stop_token in ri: stopped[i] = True - - resps_list[i] = torch.cat([resps_list[i], ri]) + sequence_list[i] = torch.cat([sequence_list[i], ri]) # stop token found stopped |= r == self.stop_token @@ -191,10 +191,9 @@ class AR(Base): sequence_list = [sequence_list[best_idx]] - res = [self._prune(r) for r in resps_list] if self.interleave: - res = [self._deinterleave(r) for r in res] - return res + sequence_list = [self._deinterleave(r) for r in sequence_list] + return [self._prune(r) for r in sequence_list] def example_usage(): diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 69377dd..9d02f4d 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -150,7 +150,7 @@ class MultiEmbedding(nn.Module): """ def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False): - super().__init__(max_n_levels, token_dim) + super().__init__() self.monolithic = monolithic self.max_n_levels = max_n_levels self.n_tokens = n_tokens diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index b0e60c6..e9f49a1 100755 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -69,6 +69,7 @@ class NAR(Base): sampling_repetition_penalty: float = 1.0, sampling_repetition_penalty_decay: float = 0.0, sampling_length_penalty: float = 0.0, + sampling_beam_width: int = 0, ): """ Args: @@ -129,9 +130,9 @@ class NAR(Base): resps_list = super().sample( logits=logits, - resps_list=resps_list, + resps_list=prev_list, quant_levels=quant_levels, - + temperature=sampling_temperature, top_p=sampling_top_p, top_k=sampling_top_k,