diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 3d34819..cac48f6 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -173,7 +173,7 @@ class AR_NAR(Base):
 					resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
 
 			# only apply stop token for RVQ level 0
-			if quant_level <= 0 and timesteps[i] is None:
+			if quant_level <= 0 and timesteps[i] is None and not self.parallel_decoding:
 				# append stop tokens for AR
 				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 818f06a..53c746b 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1273,20 +1273,21 @@ class Base(nn.Module):
 				input if quant_level == 0 else input[:, :quant_level]
 			)
 
-		if self.version < 7 or not self.parallel_decoding:
+		if self.version < 7: # or not self.parallel_decoding:
 			return self.proms_emb(
 				input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
 				quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
 				offset = 0,
 			)
-		"""
+		if not self.parallel_decoding:
 			return self.proms_emb(
-				input if input.dim() == 1 else input[:, :quant_level+1],
-				quant_level = quant_level,
+				input,
+				quant_level = 0 if input.dim() == 1 else input.shape[-1],
 				offset = 0,
 			)
 		"""
+		"""
 
 		return self.proms_emb(
 			input
 		)