this seems preferable

This commit is contained in:
mrq 2025-02-12 00:36:50 -06:00
parent 04fef5dad5
commit 4b31f5c808
2 changed files with 6 additions and 5 deletions

View File

@@ -173,7 +173,7 @@ class AR_NAR(Base):
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
if quant_level <= 0 and timesteps[i] is None:
if quant_level <= 0 and timesteps[i] is None and not self.parallel_decoding:
# append stop tokens for AR
if task not in text_task:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])

View File

@@ -1273,20 +1273,21 @@ class Base(nn.Module):
input if quant_level == 0 else input[:, :quant_level]
)
if self.version < 7 or not self.parallel_decoding:
if self.version < 7: # or not self.parallel_decoding:
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)
"""
if not self.parallel_decoding:
return self.proms_emb(
input if input.dim() == 1 else input[:, :quant_level+1],
quant_level = quant_level,
input,
quant_level = 0 if input.dim() == 1 else input.shape[-1],
offset = 0,
)
"""
"""
return self.proms_emb( input )