tweaked initial NAR pass's initial token embeddings to use a different value, or osmething
This commit is contained in:
parent
26f74c5739
commit
9e1989be1b
|
@ -1024,11 +1024,15 @@ class Base(nn.Module):
|
||||||
embedding = self.tones_emb( input )
|
embedding = self.tones_emb( input )
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
if "len" in self.capabilities and quant_level == 0:
|
if "len" in self.capabilities and quant_level == 0:
|
||||||
|
"""
|
||||||
# fill with "stop" tokens for NAR-only model
|
# fill with "stop" tokens for NAR-only model
|
||||||
embedding = self.resps_emb(
|
embedding = self.resps_emb(
|
||||||
torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
|
torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
|
||||||
offset = 0
|
offset = 0
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
# fill with filler tokens for NAR-only model
|
||||||
|
embedding = self.dropout_token.repeat((input.shape[0], 1))
|
||||||
else:
|
else:
|
||||||
# get RVQ level 0, or up to targetted RVQ level inference
|
# get RVQ level 0, or up to targetted RVQ level inference
|
||||||
if self.version <= 4:
|
if self.version <= 4:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user