this seems preferable
This commit is contained in:
parent
04fef5dad5
commit
4b31f5c808
|
@ -173,7 +173,7 @@ class AR_NAR(Base):
|
|||
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
||||
|
||||
# only apply stop token for RVQ level 0
|
||||
if quant_level <= 0 and timesteps[i] is None:
|
||||
if quant_level <= 0 and timesteps[i] is None and not self.parallel_decoding:
|
||||
# append stop tokens for AR
|
||||
if task not in text_task:
|
||||
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
|
||||
|
|
|
@ -1273,20 +1273,21 @@ class Base(nn.Module):
|
|||
input if quant_level == 0 else input[:, :quant_level]
|
||||
)
|
||||
|
||||
if self.version < 7 or not self.parallel_decoding:
|
||||
if self.version < 7: # or not self.parallel_decoding:
|
||||
return self.proms_emb(
|
||||
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
|
||||
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||
offset = 0,
|
||||
)
|
||||
"""
|
||||
|
||||
if not self.parallel_decoding:
|
||||
return self.proms_emb(
|
||||
input if input.dim() == 1 else input[:, :quant_level+1],
|
||||
quant_level = quant_level,
|
||||
input,
|
||||
quant_level = 0 if input.dim() == 1 else input.shape[-1],
|
||||
offset = 0,
|
||||
)
|
||||
"""
|
||||
"""
|
||||
|
||||
return self.proms_emb( input )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user