some day I'll get it right

mrq 2023-09-08 15:36:26 -05:00
parent b2907ae7e0
commit f69aad9c65
3 changed files with 7 additions and 9 deletions

View File

@@ -163,7 +163,6 @@ class Model:
 	arch_type: str = "transformer"
 	training: bool = True
 	interleave: bool = False
-	use_multiembedding: bool = True # nasty bandaid I got myself into
 	frozen_params: list[str] = field(default_factory=lambda: [])
 
 	@property

View File

@@ -107,7 +107,7 @@ class AR_NAR(Base):
 		prev_list = resps_list
 
 		while True:
-			level = prev_list[0].shape[-1] - 1
+			level = prev_list[0].shape[-1]
 
 			if level >= self.n_resp_levels:
 				break
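For reference, a minimal standalone sketch of what the corrected loop condition does (the tensor sizes, the stubbed NAR step, and the n_resp_levels value here are made up for illustration): each entry of prev_list is shaped (t, levels-so-far), so shape[-1] is already the next RVQ level to infer, and subtracting 1 would point the loop back at a level that has been produced.

import torch

n_resp_levels = 4                   # hypothetical total number of RVQ levels
prev_list = [torch.zeros(5, 1)]     # AR output: one sequence, level 0 only

while True:
	level = prev_list[0].shape[-1]  # next level to infer: 1, 2, ...
	if level >= n_resp_levels:
		break
	# stand-in for the NAR forward pass: append one dummy level per sequence
	prev_list = [torch.cat([p, torch.zeros(p.shape[0], 1)], dim=-1) for p in prev_list]

print(prev_list[0].shape)           # torch.Size([5, 4]): levels 0..3 filled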

View File

@@ -64,13 +64,10 @@ class MultiEmbedding(nn.Embedding):
 	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
 		super().__init__(max_n_levels, token_dim)
+		self.monolithic = monolithic
 		self.max_n_levels = max_n_levels
 		self.n_tokens = n_tokens
 
-		self.monolithic = monolithic
-		if self.monolithic:
-			self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))
-		else:
-			self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
+		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
 
 	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
 	# I imagine this is an oversight in the NAR.
@@ -78,14 +75,16 @@ class MultiEmbedding(nn.Embedding):
 		if len(x_list) == 0:
 			return []
 
+		# this "strategy" will reserve the weight[0] for the AR and weight[1:] for the NAR
+		# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
 		if self.monolithic:
-			w = self.weights[0 if quant_levels is None else 1]
+			w = self.weight[:1] if quant_levels is None else self.weight[1:]
 		else:
 			w = self.weight
 
 		padded_x_list = []
 
-		for xi in x_list:
+		for i, xi in enumerate(x_list):
 			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
 			wi = w.shape[0] - xi.shape[1]
 			xi = F.pad(xi, (0, 0, 0, wi)) # t l k
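For context, a toy, self-contained rerun of the embedding math this hunk touches (the sizes, the quant_levels handling, and the dropped non-monolithic branch are simplified or made up for illustration, so treat this as a sketch rather than the module itself): weight[:1] serves the AR (RVQ level 0), weight[1:] serves the NAR, and each sequence's levels are one-hot encoded, padded up to the selected weight's level count, then summed into a single embedding per timestep.

import torch
import torch.nn.functional as F
from torch import einsum

max_n_levels, n_tokens, token_dim = 4, 10, 8
weight = torch.randn(max_n_levels, n_tokens, token_dim)

def embed(x_list, quant_levels=None):
	# monolithic weight selection: level 0 is reserved for the AR, the rest for the NAR
	w = weight[:1] if quant_levels is None else weight[1:]
	padded_x_list = []
	for xi in x_list:                                              # xi: (t, l') token ids
		xi = F.one_hot(xi.to(torch.int64), num_classes=n_tokens)  # t l' k
		xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1]))       # t l k
		padded_x_list.append(xi.to(w))
	x = torch.cat(padded_x_list)                                   # n l k
	x = einsum("l k d, n l k -> n d", w, x)                        # sum across levels
	return x.split([len(xi) for xi in x_list])

ar_out  = embed([torch.randint(0, n_tokens, (5, 1))])                  # AR: level 0 only
nar_out = embed([torch.randint(0, n_tokens, (5, 2))], quant_levels=1)  # NAR: levels 1..2
print(ar_out[0].shape, nar_out[0].shape)                               # both torch.Size([5, 8])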