tweaked initial NAR pass's initial token embeddings to use a different value, or osmething

This commit is contained in:
mrq 2024-08-03 09:01:37 -05:00
parent 26f74c5739
commit 9e1989be1b

View File

@ -1024,11 +1024,15 @@ class Base(nn.Module):
embedding = self.tones_emb( input ) embedding = self.tones_emb( input )
elif name == "resp": elif name == "resp":
if "len" in self.capabilities and quant_level == 0: if "len" in self.capabilities and quant_level == 0:
"""
# fill with "stop" tokens for NAR-only model # fill with "stop" tokens for NAR-only model
embedding = self.resps_emb( embedding = self.resps_emb(
torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token), torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
offset = 0 offset = 0
) )
"""
# fill with filler tokens for NAR-only model
embedding = self.dropout_token.repeat((input.shape[0], 1))
else: else:
# get RVQ level 0, or up to targetted RVQ level inference # get RVQ level 0, or up to targetted RVQ level inference
if self.version <= 4: if self.version <= 4: